151 lines
5.4 KiB
Python
151 lines
5.4 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import sqlite3
|
||
|
|
from dataclasses import fields
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest
|
||
|
|
|
||
|
|
|
||
|
|
def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]:
    """Read the universe config JSON and build the backtest inputs.

    Args:
        config_path: Path to a UTF-8 JSON file holding an ``assets`` list and
            an optional ``constraints`` object.

    Returns:
        ``(universe, constraints, risk_proxy, rates_fallback)``; the last two
        are returned as strings.

    Raises:
        KeyError: If the config has no ``assets`` key.
    """
    raw = json.loads(config_path.read_text(encoding="utf-8"))
    assets = [UniverseAsset(**spec) for spec in raw["assets"]]

    section = raw.get("constraints", {})
    required = section.get("must_include", {})
    limits = Constraints(
        max_positions=int(section.get("max_positions", 3)),
        must_commodity=int(required.get("commodity", 0)),
        must_rates=int(required.get("rates", 0)),
        must_equity=int(required.get("equity", 0)),
    )

    # A truthy configured risk_proxy wins; otherwise use the first asset's
    # code, and only fall back to the hard-coded code for an empty universe.
    proxy = section.get("risk_proxy")
    if not proxy:
        proxy = assets[0].ts_code if assets else "510300.SH"
    fallback = section.get("rates_fallback", "511010.SH")

    return assets, limits, str(proxy), str(fallback)
|
||
|
|
|
||
|
|
|
||
|
|
def load_prices(raw_dir: Path, universe: list[UniverseAsset], start: str, end: str) -> dict[str, pd.DataFrame]:
    """Load one parquet price file per asset, restricted to a date window.

    The file for an asset is ``<raw_dir>/<ts_code with dots removed>.parquet``.
    ``trade_date`` is cast to ``str`` and compared as a string, so ``start``
    and ``end`` are inclusive bounds in the same textual format as the column
    (presumably ``YYYYMMDD`` -- confirm against the data files).

    Args:
        raw_dir: Directory containing the parquet files.
        universe: Assets to load; only ``ts_code`` is read from each.
        start: Inclusive lower bound for ``trade_date``.
        end: Inclusive upper bound for ``trade_date``.

    Returns:
        Mapping from ``ts_code`` to the filtered frame.
    """

    def _windowed(ts_code):
        # Read, normalise the date column to text, then keep rows in range.
        frame = pd.read_parquet(raw_dir / f"{ts_code.replace('.', '')}.parquet").copy()
        frame["trade_date"] = frame["trade_date"].astype(str)
        in_window = (frame["trade_date"] >= start) & (frame["trade_date"] <= end)
        return frame[in_window]

    return {asset.ts_code: _windowed(asset.ts_code) for asset in universe}
|
||
|
|
|
||
|
|
|
||
|
|
def perf_stats(equity: pd.Series) -> dict[str, float]:
    """Summarise an equity curve with annualised performance figures.

    Annualisation uses 252 periods per year. ``sharpe`` is the annualised
    return divided by the annualised volatility (no risk-free rate is
    subtracted) and is NaN whenever the volatility is not strictly positive.

    Args:
        equity: Equity curve values in chronological order.

    Returns:
        A dict with ``ann_return``, ``ann_vol``, ``max_drawdown`` and
        ``sharpe``; an empty dict when the curve yields no period returns.
    """
    returns = equity.pct_change().dropna()
    n_periods = len(returns)
    if n_periods == 0:
        return {}

    # Total growth over the curve, compounded to a 252-period year.
    growth = equity.iloc[-1] / equity.iloc[0]
    ann_return = float(growth ** (252 / n_periods) - 1)
    ann_vol = float(returns.std(ddof=1) * (252**0.5))

    # Drawdown is measured against the running peak of the curve.
    drawdown = equity / equity.cummax() - 1.0

    if ann_vol > 0:
        sharpe = float(ann_return / ann_vol)
    else:
        sharpe = float("nan")

    return {
        "ann_return": ann_return,
        "ann_vol": ann_vol,
        "max_drawdown": float(drawdown.min()),
        "sharpe": sharpe,
    }
|
||
|
|
|
||
|
|
|
||
|
|
def table_columns(con: sqlite3.Connection, table: str) -> list[str]:
    """Return the column names of ``table`` in declaration order.

    Args:
        con: Open sqlite connection.
        table: Table name; it may contain spaces or quote characters.

    Returns:
        Column names as reported by ``PRAGMA table_info``; an empty list when
        the table does not exist.
    """
    # The name has to be embedded in the statement text, so quote it as an
    # identifier (doubling embedded quotes) instead of interpolating it raw.
    # This keeps odd table names working and prevents SQL injection.
    quoted = '"' + table.replace('"', '""') + '"'
    return [row[1] for row in con.execute(f"PRAGMA table_info({quoted})")]
|
||
|
|
|
||
|
|
|
||
|
|
def fetch_topn(db_path: Path, run_id: str | None, topn: int) -> tuple[list[str], list[dict[str, Any]]]:
    """Fetch the ``topn`` trials with the highest ``ann_return``.

    Args:
        db_path: Path to the sqlite database containing a ``trials`` table.
        run_id: When truthy, only trials with this ``run_id`` are considered.
        topn: Maximum number of rows to return.

    Returns:
        ``(cols, rows)``: the column names of ``trials`` and one dict per
        selected row keyed by those names, ordered by ``ann_return``
        descending.
    """
    # sqlite3's connection context manager only commits or rolls back; it
    # does not close the connection. Close it explicitly so the handle is
    # released even when a query raises.
    con = sqlite3.connect(str(db_path))
    try:
        cols = table_columns(con, "trials")

        where = ""
        params: list[Any] = []
        if run_id:
            where = "WHERE run_id = ?"
            params.append(run_id)

        # Only the fixed WHERE fragment is interpolated; values are bound.
        sql = f"SELECT * FROM trials {where} ORDER BY ann_return DESC LIMIT ?"
        rows: list[dict[str, Any]] = [
            dict(zip(cols, record)) for record in con.execute(sql, [*params, int(topn)])
        ]
    finally:
        con.close()
    return cols, rows
|
||
|
|
|
||
|
|
|
||
|
|
def _coerce_param(value: Any, target: type) -> Any:
    """Convert a value read from sqlite to the type of a TrendParams default."""
    if value is None:
        return None
    if target is int:
        # sqlite may hand integers back as REAL, e.g. 20.0.
        return int(round(float(value)))
    if target is bool:
        return bool(int(round(float(value))))
    if target is str:
        return str(value)
    try:
        return float(value)
    except (TypeError, ValueError):
        # Not numeric: pass the stored value through instead of crashing.
        return value


def _metric_diff(recomputed: Any, stored: Any) -> float:
    """Return ``recomputed - stored``, with explicit handling of missing values.

    Both missing (None/NaN) counts as agreement (0.0). A missing recomputed
    value against a present stored one yields NaN. A missing stored value is
    treated as 0.0 when the recomputed value is present.
    """
    new_missing = recomputed is None or pd.isna(recomputed)
    old_missing = stored is None or pd.isna(stored)
    if new_missing:
        return 0.0 if old_missing else float("nan")
    return float(recomputed) - float(stored or 0.0)


def main() -> None:
    """Re-run the top stored trials and compare recomputed stats to stored ones.

    Command-line flags:
        --db       sqlite database with the ``trials`` table.
        --run_id   optional run filter.
        --topn     number of best trials (by ``ann_return``) to re-run.
        --config   universe/constraints JSON file.
        --rawdir   directory with the per-asset parquet files.
        --start / --end   override the window stored on each trial row.
        --tol      absolute tolerance for each metric difference.

    For every trial a line tagged ``OK`` or ``MISMATCH`` is printed, followed
    by the stored metrics, the recomputed metrics and their differences; a
    final line reports the mismatch count.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--db", default="data/experiments.sqlite")
    ap.add_argument("--run_id", default=None)
    ap.add_argument("--topn", type=int, default=10)
    ap.add_argument("--config", default="configs/etf_universe_industry_profiled.json")
    ap.add_argument("--rawdir", default="data/raw")
    ap.add_argument("--start", default=None)
    ap.add_argument("--end", default=None)
    ap.add_argument("--tol", type=float, default=1e-6)
    args = ap.parse_args()

    cols, rows = fetch_topn(Path(args.db), args.run_id, args.topn)
    if not rows:
        print("no trials found")
        return

    universe, constraints, risk_proxy, rates_fallback = load_universe(Path(args.config))

    # The runtime type of each TrendParams default decides how a stored value
    # is coerced, since sqlite does not preserve the Python types.
    defaults = TrendParams()
    field_types = {f.name: type(getattr(defaults, f.name)) for f in fields(TrendParams)}

    metrics = ("ann_return", "ann_vol", "max_drawdown", "sharpe")
    raw_dir = Path(args.rawdir)
    tol = float(args.tol)
    # Trials usually share one window; read the parquet files once per window.
    price_cache: dict[tuple[str, str], dict[str, pd.DataFrame]] = {}

    mismatches = 0
    for idx, row in enumerate(rows, start=1):
        start = str(args.start or row.get("start") or "20200101")
        end = str(args.end or row.get("end") or "20251231")

        window = (start, end)
        if window not in price_cache:
            price_cache[window] = load_prices(raw_dir, universe, start, end)
        # Hand each backtest its own frames so a run cannot affect the cache.
        prices = {code: frame.copy() for code, frame in price_cache[window].items()}

        params_dict: dict[str, Any] = {
            k: _coerce_param(row[k], field_types[k])
            for k in cols
            if k in field_types and row.get(k) is not None
        }
        params_dict.setdefault("max_positions", constraints.max_positions)

        tp = TrendParams(**params_dict)
        equity, _weights, _trades = run_backtest(
            prices,
            universe,
            constraints,
            tp,
            rates_fallback=rates_fallback,
            risk_proxy=risk_proxy,
        )

        # perf_stats returns {} for a curve without returns, hence .get().
        st = perf_stats(equity["equity"])
        diffs = {k: _metric_diff(st.get(k), row.get(k)) for k in metrics}
        # "not (<=)" rather than ">" so that a NaN difference is a mismatch.
        bad = any(not (abs(v) <= tol) for v in diffs.values())
        if bad:
            mismatches += 1

        tag = "MISMATCH" if bad else "OK"
        print(f"[{idx}] {tag} id={row.get('id')} run_id={row.get('run_id')} start={start} end={end}")
        print("  orig:", {k: row.get(k) for k in metrics})
        print("  re  :", st)
        print("  diff:", diffs)

    print(f"done. mismatches={mismatches}/{len(rows)}")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the verification only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|