# quant-factor-research/scripts/iterate_optimize.py
# (Python source; the repository viewer listed it as 500 lines, 18 KiB.)

from __future__ import annotations
import argparse
import json
import os
import random
import sqlite3
from dataclasses import asdict, replace
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest
# Shared backtest inputs read by _eval_one and published by _init_globals() in
# the parent process before the worker pool is created. With the "fork" start
# method used by main(), worker processes inherit these objects at fork time
# rather than having them pickled together with every submitted task.
_G_PRICES: dict[str, pd.DataFrame] | None = None  # ts_code -> frame from load_prices
_G_UNIVERSE: list[UniverseAsset] | None = None
_G_CONSTRAINTS: Constraints | None = None
_G_RISK_PROXY: str | None = None  # passed to run_backtest(risk_proxy=...)
_G_RATES_FALLBACK: str | None = None  # passed to run_backtest(rates_fallback=...)
def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]:
    """Parse the universe JSON config.

    Returns ``(universe, constraints, risk_proxy, rates_fallback)``:

    - ``universe``: one ``UniverseAsset`` per entry of ``conf["assets"]``.
    - ``constraints``: built from the optional ``conf["constraints"]`` section
      (``max_positions`` defaults to 3, every ``must_include`` count to 0).
    - ``risk_proxy``: ``constraints.risk_proxy`` if truthy, else the first
      asset's ts_code, else ``"510300.SH"`` for an empty universe.
    - ``rates_fallback``: ``constraints.rates_fallback`` or ``"511010.SH"``.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))
    assets = [UniverseAsset(**spec) for spec in config["assets"]]

    section = config.get("constraints", {})
    required = section.get("must_include", {})
    constraints = Constraints(
        max_positions=int(section.get("max_positions", 3)),
        must_commodity=int(required.get("commodity", 0)),
        must_rates=int(required.get("rates", 0)),
        must_equity=int(required.get("equity", 0)),
    )

    risk_proxy = section.get("risk_proxy")
    if not risk_proxy:
        risk_proxy = assets[0].ts_code if assets else "510300.SH"
    rates_fallback = section.get("rates_fallback", "511010.SH")
    return assets, constraints, str(risk_proxy), str(rates_fallback)
def load_prices(raw_dir: Path, universe: list[UniverseAsset], start: str, end: str) -> dict[str, pd.DataFrame]:
    """Load each asset's parquet file and keep rows with start <= trade_date <= end.

    The file for an asset is ``raw_dir / "<ts_code without dots>.parquet"``
    (e.g. ``510300.SH`` -> ``510300SH.parquet``) and must have a ``trade_date``
    column. ``trade_date`` is converted to ``str`` and compared lexically with
    ``start``/``end``, which assumes a zero-padded ``YYYYMMDD`` format (the CLI
    defaults use it) -- TODO confirm for every data source.

    Returns a dict ``ts_code -> DataFrame``. Each frame is an independent copy,
    so callers may add or modify columns freely.
    """
    out: dict[str, pd.DataFrame] = {}
    for asset in universe:
        path = raw_dir / (asset.ts_code.replace(".", "") + ".parquet")
        # read_parquet returns a fresh frame, so it is safe to assign in place.
        df = pd.read_parquet(path)
        df["trade_date"] = df["trade_date"].astype(str)
        in_window = (df["trade_date"] >= start) & (df["trade_date"] <= end)
        # Copy *after* filtering: a boolean-filtered frame is flagged as a
        # possible view, and later column writes would raise
        # SettingWithCopyWarning.
        out[asset.ts_code] = df[in_window].copy()
    return out
def perf_stats(equity: pd.Series) -> dict[str, float]:
    """Summarise an equity curve.

    Annualisation uses 252 periods per year (i.e. assumes daily bars).

    Returns ``{}`` when the curve yields no period returns (e.g. a single
    point). Otherwise returns:

    - ``ann_return``: geometric growth from first to last value, annualised
      over the number of returns.
    - ``ann_vol``: sample std (ddof=1) of period returns times sqrt(252).
    - ``max_drawdown``: most negative ``equity / running_peak - 1``.
    - ``sharpe``: ``ann_return / ann_vol`` (no risk-free rate); NaN unless
      ``ann_vol > 0``.
    """
    period_returns = equity.pct_change().dropna()
    n_periods = len(period_returns)
    if n_periods == 0:
        return {}

    growth = equity.iloc[-1] / equity.iloc[0]
    ann_return = float(growth ** (252 / n_periods) - 1)
    ann_vol = float(period_returns.std(ddof=1) * (252**0.5))

    drawdowns = equity / equity.cummax() - 1.0
    max_drawdown = float(drawdowns.min())

    if ann_vol > 0:
        sharpe = float(ann_return / ann_vol)
    else:
        sharpe = float("nan")

    return {
        "ann_return": ann_return,
        "ann_vol": ann_vol,
        "max_drawdown": max_drawdown,
        "sharpe": sharpe,
    }
def trades_per_year(trades: pd.DataFrame, start: str, end: str) -> float:
    """Return the number of trade records per year over the [start, end] window.

    ``start`` and ``end`` are ``YYYYMMDD`` strings, as used on the CLI. The
    window length is counted in days, inclusive of both endpoints, and
    converted to years at 365.25 days/year.

    Counting whole calendar years instead would treat a Dec-to-Jan window as
    two years and a half-year window as one. That misstates the rate checked
    against ``--max_trades_per_year``.

    Returns 0.0 when ``trades`` is None or empty.

    Raises ValueError if ``start``/``end`` are not valid ``YYYYMMDD`` dates.
    """
    if trades is None or trades.empty:
        return 0.0
    first = datetime.strptime(start, "%Y%m%d")
    last = datetime.strptime(end, "%Y%m%d")
    # Inclusive day count; floor at one day so a reversed window cannot
    # produce zero or negative years.
    days = max((last - first).days + 1, 1)
    years = days / 365.25
    return float(len(trades) / years)
def load_state(path: Path) -> dict:
    """Load the optimizer state JSON from ``path``, or a fresh state if absent.

    The returned dict always has the keys ``best``,
    ``last_reported_ann_return`` and ``history``. Keys missing from an older
    or partial file are filled with their defaults (``None``, ``None``,
    ``[]``), because ``main`` reads them directly and appends to
    ``state["history"]``. Any extra keys in the file are preserved.
    """
    state: dict = {"best": None, "last_reported_ann_return": None, "history": []}
    if path.exists():
        state.update(json.loads(path.read_text(encoding="utf-8")))
    if state.get("history") is None:
        state["history"] = []
    return state
def save_state(path: Path, state: dict) -> None:
    """Write ``state`` to ``path`` as indented ASCII JSON, creating parent dirs.

    The JSON goes to a sibling ``<name>.tmp`` file first and is then moved over
    ``path`` with ``os.replace``. This is an atomic rename on POSIX when both
    paths are on the same filesystem. ``main`` saves on every new best trial,
    so an interrupted run must not leave a truncated state file behind.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialise before touching the filesystem so an encoding error writes nothing.
    payload = json.dumps(state, ensure_ascii=True, indent=2) + "\n"
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    os.replace(tmp_path, path)
def infer_code_version(repo_dir: Path) -> str:
    """Best-effort identifier of the checked-out code, recorded with each run.

    Reads ``repo_dir/.git`` directly (no ``git`` subprocess):

    - Detached HEAD: returns the text stored in ``HEAD`` (the commit hash).
    - Symbolic ref (``ref: refs/heads/...``): returns the hash from the loose
      ref file, or from ``packed-refs`` when the ref has been packed (e.g.
      after ``git gc`` or on a fresh clone). If neither resolves it, returns
      the raw ``HEAD`` text.

    Returns ``"nogit"`` when ``.git/HEAD`` does not exist, and ``"unknown"`` if
    reading or decoding the git files fails.
    """
    git_dir = repo_dir / ".git"
    head = git_dir / "HEAD"
    if not head.exists():
        return "nogit"
    try:
        txt = head.read_text(encoding="utf-8").strip()
        if not txt.startswith("ref:"):
            return txt
        ref = txt[len("ref:"):].strip()
        ref_path = git_dir / ref
        if ref_path.exists():
            return ref_path.read_text(encoding="utf-8").strip()
        packed = git_dir / "packed-refs"
        if packed.exists():
            for line in packed.read_text(encoding="utf-8").splitlines():
                # Entries are "<sha> <refname>". '#' starts the header line and
                # '^' marks the peeled object of the preceding annotated tag.
                if not line or line.startswith(("#", "^")):
                    continue
                sha, _, name = line.partition(" ")
                if name.strip() == ref:
                    return sha
        return txt
    except (OSError, ValueError):
        # ValueError covers UnicodeDecodeError from a corrupt ref file.
        return "unknown"
def ensure_db(db_path: Path, param_cols: list[str]) -> None:
    """Create the experiments database and make sure ``trials`` has every column.

    Steps:

    - Create ``db_path``'s parent directory.
    - Switch the database to WAL journaling.
    - Create the ``trials`` table if it does not exist.
    - Add a ``REAL`` column for each name in ``param_cols`` that the table
      lacks, so databases created by older runs pick up new parameter fields.

    Safe to call repeatedly.
    """
    from contextlib import closing

    db_path.parent.mkdir(parents=True, exist_ok=True)
    # A Connection used as a context manager only commits/rolls back; closing()
    # also releases the connection and its file handles.
    with closing(sqlite3.connect(str(db_path))) as con:
        with con:
            con.execute("PRAGMA journal_mode=WAL")
            # NOTE: synchronous is a per-connection setting in SQLite, so this
            # only affects this (short-lived) connection.
            con.execute("PRAGMA synchronous=NORMAL")
            con.execute(
                """
                CREATE TABLE IF NOT EXISTS trials (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id TEXT NOT NULL,
                    ts_utc TEXT NOT NULL,
                    code_version TEXT,
                    config_path TEXT,
                    start TEXT,
                    end TEXT,
                    seed INTEGER,
                    trial INTEGER,
                    jobs INTEGER,
                    ann_return REAL,
                    ann_vol REAL,
                    max_drawdown REAL,
                    sharpe REAL,
                    trades_per_year REAL
                )
                """
            )
            # Check existing columns instead of swallowing every
            # OperationalError from ALTER TABLE, which would also hide real
            # failures such as a locked database. SQLite names are
            # case-insensitive.
            existing = {row[1].lower() for row in con.execute("PRAGMA table_info(trials)")}
            for col in param_cols:
                if col.lower() in existing:
                    continue
                # Column names come from code (TrendParams fields in main), not
                # user input. Quoted so SQL keywords are valid column names.
                con.execute(f'ALTER TABLE trials ADD COLUMN "{col}" REAL')
                existing.add(col.lower())
def insert_rows(db_path: Path, param_cols: list[str], rows: list[dict[str, Any]]) -> None:
    """Append ``rows`` to the ``trials`` table in a single transaction.

    Each row dict is read for the fixed metadata/metric columns plus every name
    in ``param_cols``. Missing keys are stored as NULL. The table and columns
    must already exist (``main`` calls ``ensure_db`` before inserting). Does
    nothing for an empty ``rows``.
    """
    if not rows:
        return
    from contextlib import closing

    cols = [
        "run_id",
        "ts_utc",
        "code_version",
        "config_path",
        "start",
        "end",
        "seed",
        "trial",
        "jobs",
        "ann_return",
        "ann_vol",
        "max_drawdown",
        "sharpe",
        "trades_per_year",
        *param_cols,
    ]
    # Quote identifiers: "end" is an SQL keyword.
    col_sql = ",".join(f'"{c}"' for c in cols)
    placeholders = ",".join(["?"] * len(cols))
    sql = f"INSERT INTO trials ({col_sql}) VALUES ({placeholders})"
    values = [[row.get(c) for c in cols] for row in rows]
    # closing() releases the connection; the inner `with con` commits on
    # success and rolls back on error.
    with closing(sqlite3.connect(str(db_path))) as con:
        with con:
            con.executemany(sql, values)
def reservoir_sample_product(rng, iterables, k: int):
    """Return up to ``k`` tuples drawn uniformly from ``itertools.product(*iterables)``.

    Uses reservoir sampling (Algorithm R), so the product is streamed rather
    than materialised. The first ``k`` combos fill the reservoir. Each later
    combo number ``seen`` (1-based) replaces slot ``rng.randrange(seen)`` when
    that slot is ``< k``. If the product has at most ``k`` elements, all of
    them are returned in product order and ``rng`` is never consulted.
    """
    import itertools

    reservoir = []
    for seen, combo in enumerate(itertools.product(*iterables), start=1):
        if len(reservoir) < k:
            reservoir.append(combo)
            continue
        slot = rng.randrange(seen)
        if slot < k:
            reservoir[slot] = combo
    return reservoir
def _init_globals(prices: dict[str, pd.DataFrame], universe: list[UniverseAsset], constraints: Constraints, risk_proxy: str, rates_fallback: str) -> None:
    """Publish the shared backtest inputs as module globals for ``_eval_one``.

    Called in the parent before any worker process is started. Forked workers
    then see the same values without them being sent with each task.
    """
    global _G_PRICES, _G_UNIVERSE, _G_CONSTRAINTS, _G_RISK_PROXY, _G_RATES_FALLBACK
    (
        _G_PRICES,
        _G_UNIVERSE,
        _G_CONSTRAINTS,
        _G_RISK_PROXY,
        _G_RATES_FALLBACK,
    ) = (prices, universe, constraints, risk_proxy, rates_fallback)
def _eval_one(task: dict[str, Any]) -> dict[str, Any] | None:
    """Backtest one candidate parameter set and score it.

    ``task`` is a dict built by ``main`` with these keys:

    - ``params``: overrides applied to a default ``TrendParams``.
    - ``start``, ``end``: the backtest window.
    - ``max_trades_per_year``: the trade-rate limit.
    - ``trial``, ``seed``: identifiers copied into the result.

    The prices, universe, constraints and proxy codes come from the module
    globals set by ``_init_globals``.

    Returns a flat dict containing:

    - the ``perf_stats`` metrics;
    - ``trades_per_year``;
    - every ``TrendParams`` field;
    - ``trial`` and ``seed``.

    Returns None when ``run_backtest`` raises, the equity curve yields no
    returns, or the trade rate exceeds ``max_trades_per_year``.

    Raises RuntimeError if ``_init_globals`` has not been called.
    """
    # Explicit check instead of asserts, which are stripped under `python -O`.
    if (
        _G_PRICES is None
        or _G_UNIVERSE is None
        or _G_CONSTRAINTS is None
        or _G_RISK_PROXY is None
        or _G_RATES_FALLBACK is None
    ):
        raise RuntimeError("_init_globals() must be called before _eval_one()")
    params = replace(TrendParams(), **task["params"])
    try:
        equity, _weights, trades = run_backtest(
            _G_PRICES,
            _G_UNIVERSE,
            _G_CONSTRAINTS,
            params,
            rates_fallback=_G_RATES_FALLBACK,
            risk_proxy=_G_RISK_PROXY,
        )
    except Exception as exc:  # one bad parameter set must not abort the sweep
        # Skip the trial but say why. Returning None silently makes a
        # systematic failure (e.g. a bug in run_backtest) indistinguishable
        # from trials that merely failed the filters below.
        import sys

        print(
            f"trial {task.get('trial')} failed: {type(exc).__name__}: {exc}",
            file=sys.stderr,
            flush=True,
        )
        return None
    stats = perf_stats(equity["equity"])
    if not stats:
        return None
    tpy = trades_per_year(trades, task["start"], task["end"])
    if tpy > float(task["max_trades_per_year"]):
        return None
    row = {**stats, "trades_per_year": float(tpy), **asdict(params)}
    row["trial"] = int(task["trial"])
    row["seed"] = int(task["seed"])
    return row
# Default for --max_grid. In grid mode, a cartesian product larger than this
# is down-sampled to this many combos with reservoir_sample_product.
MAX_GRID_COMBOS = 128
# Candidate values for each tuned TrendParams field. Grid mode enumerates (or
# samples) the cartesian product. Random mode draws each field independently,
# consuming the RNG in this key order, so reordering keys changes which
# parameter sets a given --seed produces.
_SEARCH_SPACE: dict[str, list[Any]] = {
    "sma_fast": [3, 5],
    "sma_slow": [15, 20, 30],
    "lazy_days": [4, 5, 6, 8],
    "min_hold_days": [2, 3, 5],
    "replace_score_gap": [0.5, 0.8, 1.2, 1.6],
    "min_score": [0.0, 0.2, 0.4, 0.6],
    "desired_positions_min": [1, 2],
    "macro_min_breadth": [0.10, 0.15, 0.20, 0.30],
    "macro_down_frac": [0.75, 0.80, 0.85],
    "atr_mult": [2.5, 3.2, 4.0],
    "stop_loss_atr": [2.0, 2.5, 3.2],
    "profit_tighten_atr": [4.0, 6.0, 8.0],
    "atr_mult_profit": [1.5, 2.0, 2.5],
    "bias_exit": [0.12, 0.18, 0.25],
    "vol_ratio_exit": [3.0, 4.0],
}


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line (``argv`` defaults to ``sys.argv[1:]``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="configs/etf_universe_industry_profiled.json")
    ap.add_argument("--rawdir", default="data/raw")
    ap.add_argument("--start", default="20200101")
    ap.add_argument("--end", default="20251231")
    ap.add_argument("--trials", type=int, default=240)
    ap.add_argument("--mode", choices=["random", "grid"], default="random")
    ap.add_argument("--max_grid", type=int, default=MAX_GRID_COMBOS)
    ap.add_argument("--seed", type=int, default=1)
    ap.add_argument("--jobs", type=int, default=1, help="Parallel workers (processes), up to 8")
    ap.add_argument("--state", default="data/opt_state.json")
    ap.add_argument("--db", default="data/experiments.sqlite")
    ap.add_argument("--baseline", type=float, default=None)
    ap.add_argument("--report_step", type=float, default=0.05)
    ap.add_argument("--max_trades_per_year", type=float, default=80.0)
    ap.add_argument("--progress_every", type=int, default=25)
    return ap.parse_args(argv)


def _params_from_point(params0: TrendParams, point: dict[str, Any]) -> TrendParams:
    """Return ``params0`` with one search-space point applied.

    Integer-valued fields are cast with ``int`` and the rest with ``float``.
    ``desired_positions_max`` is pinned to 3 and ``rebalance_every`` to 1 for
    every candidate.
    """
    return replace(
        params0,
        sma_fast=int(point["sma_fast"]),
        sma_slow=int(point["sma_slow"]),
        lazy_days=int(point["lazy_days"]),
        min_hold_days=int(point["min_hold_days"]),
        replace_score_gap=float(point["replace_score_gap"]),
        min_score=float(point["min_score"]),
        desired_positions_min=int(point["desired_positions_min"]),
        desired_positions_max=3,
        macro_min_breadth=float(point["macro_min_breadth"]),
        macro_down_frac=float(point["macro_down_frac"]),
        atr_mult=float(point["atr_mult"]),
        stop_loss_atr=float(point["stop_loss_atr"]),
        profit_tighten_atr=float(point["profit_tighten_atr"]),
        atr_mult_profit=float(point["atr_mult_profit"]),
        bias_exit=float(point["bias_exit"]),
        vol_ratio_exit=float(point["vol_ratio_exit"]),
        rebalance_every=1,
    )


def _build_tasks(
    args: argparse.Namespace,
    params0: TrendParams,
    param_cols: list[str],
    rng: random.Random,
) -> list[dict[str, Any]]:
    """Generate the ``_eval_one`` task dicts for the selected ``--mode``.

    ``trial`` is the candidate's index in generation order. Candidates with
    ``sma_fast >= sma_slow`` are dropped without renumbering the rest.
    """
    import itertools

    keys = list(_SEARCH_SPACE)
    if str(args.mode) == "grid":
        value_lists = [list(_SEARCH_SPACE[k]) for k in keys]
        total = 1
        for values in value_lists:
            total *= max(1, len(values))
        max_grid = max(1, int(args.max_grid))
        if total > max_grid:
            print(f"grid combos {total} > {max_grid}; sampling combos", flush=True)
            combos = reservoir_sample_product(rng, value_lists, max_grid)
        else:
            combos = list(itertools.product(*value_lists))
        points = [dict(zip(keys, combo)) for combo in combos]
    else:
        points = [{k: rng.choice(_SEARCH_SPACE[k]) for k in keys} for _ in range(int(args.trials))]

    tasks: list[dict[str, Any]] = []
    for trial, point in enumerate(points):
        if int(point["sma_fast"]) >= int(point["sma_slow"]):
            continue
        params = asdict(_params_from_point(params0, point))
        tasks.append(
            {
                "trial": int(trial),
                "seed": int(args.seed),
                "start": str(args.start),
                "end": str(args.end),
                "max_trades_per_year": float(args.max_trades_per_year),
                "params": {k: params[k] for k in param_cols},
            }
        )
    return tasks


def _iter_results(tasks: list[dict[str, Any]], jobs: int, init_args: tuple[Any, ...]):
    """Yield the non-None ``_eval_one`` results, in-process or via a process pool.

    The parallel path yields rows in completion order. It prefers the "fork"
    start method, so workers inherit the loaded prices, and falls back to
    "spawn" where fork is unavailable (e.g. Windows). The pool initializer
    re-publishes ``init_args`` through ``_init_globals`` in every worker; spawn
    requires this, and under fork it is effectively free.
    """
    if jobs == 1:
        for task in tasks:
            row = _eval_one(task)
            if row is not None:
                yield row
        return

    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor, as_completed

    method = "fork" if "fork" in mp.get_all_start_methods() else "spawn"
    ctx = mp.get_context(method)
    with ProcessPoolExecutor(
        max_workers=jobs,
        mp_context=ctx,
        initializer=_init_globals,
        initargs=init_args,
    ) as ex:
        futs = [ex.submit(_eval_one, task) for task in tasks]
        for fut in as_completed(futs):
            row = fut.result()
            if row is not None:
                yield row


def main() -> None:
    """Run one parameter-search session for the ETF trend strategy.

    Steps:

    1. Parse the CLI and load the universe config and price data.
    2. Build candidate ``TrendParams`` (grid or random mode) and backtest them,
       sequentially or in up to 8 worker processes.
    3. Log every valid trial to the SQLite ``trials`` table.
    4. Keep the best annualised return in the JSON state file and append a
       history entry for the session.
    5. Print the top 12 trials. Print ``REPORT_TRIGGER`` when the best return
       beats the last reported value by at least ``--report_step``.
    """
    args = _parse_args()
    jobs = max(1, min(8, int(args.jobs)))
    random.seed(args.seed)
    np.random.seed(args.seed)

    config_path = Path(args.config)
    universe, constraints, risk_proxy, rates_fallback = load_universe(config_path)
    prices = load_prices(Path(args.rawdir), universe, args.start, args.end)
    shared = (prices, universe, constraints, risk_proxy, rates_fallback)
    _init_globals(*shared)

    state_path = Path(args.state)
    state = load_state(state_path)
    # Older or hand-edited state files may lack "history"; it is appended below.
    state.setdefault("history", [])
    best = state.get("best")
    best_ann = float(best["ann_return"]) if best else float("-inf")
    baseline = args.baseline
    if baseline is None:
        baseline = best_ann if np.isfinite(best_ann) else 0.0
    last_rep = state.get("last_reported_ann_return")
    if last_rep is None:
        last_rep = baseline

    params0 = TrendParams(max_positions=constraints.max_positions)
    # Every TrendParams field is persisted as its own SQLite column.
    param_cols = sorted(asdict(params0).keys())
    db_path = Path(args.db)
    ensure_db(db_path, param_cols=param_cols)
    run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + f"_seed{int(args.seed)}"
    code_version = infer_code_version(Path("."))

    rng = random.Random(int(args.seed))
    tasks = _build_tasks(args, params0, param_cols, rng)

    results: list[dict[str, Any]] = []
    rows_for_db: list[dict[str, Any]] = []

    def flush_rows() -> None:
        if rows_for_db:
            insert_rows(db_path, param_cols=param_cols, rows=rows_for_db)
            rows_for_db.clear()

    def record_row(row: dict[str, Any]) -> None:
        nonlocal best_ann
        results.append(row)
        if float(row["ann_return"]) > best_ann:
            best_ann = float(row["ann_return"])
            state["best"] = row
            save_state(state_path, state)
        db_row = {
            "run_id": run_id,
            "ts_utc": datetime.now(timezone.utc).isoformat(),
            "code_version": code_version,
            "config_path": str(config_path),
            "start": str(args.start),
            "end": str(args.end),
            "seed": int(args.seed),
            "trial": int(row.get("trial", -1)),
            "jobs": int(jobs),
            "ann_return": float(row["ann_return"]),
            "ann_vol": float(row["ann_vol"]),
            "max_drawdown": float(row["max_drawdown"]),
            "sharpe": float(row["sharpe"]),
            "trades_per_year": float(row["trades_per_year"]),
        }
        for c in param_cols:
            db_row[c] = row.get(c)
        rows_for_db.append(db_row)
        # Insert in batches rather than one transaction per trial.
        if len(rows_for_db) >= 200:
            flush_rows()

    progress_every = int(args.progress_every)
    try:
        for row in _iter_results(tasks, jobs, shared):
            record_row(row)
            if progress_every > 0 and len(results) % progress_every == 0:
                print(f"progress valid={len(results)} best_ann={best_ann:.4f}", flush=True)
    finally:
        # Persist buffered rows even if a worker fails or the run is interrupted.
        flush_rows()

    state["history"].append(
        {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "run_id": run_id,
            "code_version": code_version,
            "config": str(args.config),
            "start": str(args.start),
            "end": str(args.end),
            "mode": str(args.mode),
            "trials": int(args.trials),
            # Number actually evaluated (differs from --trials in grid mode).
            "tasks": len(tasks),
            "jobs": int(jobs),
            "best_ann_return": float(best_ann) if np.isfinite(best_ann) else None,
            "db": str(args.db),
        }
    )
    save_state(state_path, state)

    if not results:
        print("no valid trials")
        return
    df = pd.DataFrame(results).sort_values(["ann_return"], ascending=False)
    cols = [
        "ann_return",
        "ann_vol",
        "max_drawdown",
        "sharpe",
        "trades_per_year",
        "sma_fast",
        "sma_slow",
        "lazy_days",
        "min_hold_days",
        "replace_score_gap",
        "min_score",
        "macro_min_breadth",
        "macro_down_frac",
        "desired_positions_min",
        "atr_mult",
        "stop_loss_atr",
        "profit_tighten_atr",
        "atr_mult_profit",
        "bias_exit",
        "vol_ratio_exit",
    ]
    cols = [c for c in cols if c in df.columns]
    print(df[cols].head(12).to_string(index=False))
    if best_ann >= float(last_rep) + float(args.report_step):
        state["last_reported_ann_return"] = float(best_ann)
        save_state(state_path, state)
        print("REPORT_TRIGGER", float(best_ann), "baseline", float(last_rep))


if __name__ == "__main__":
    main()