"""Re-run the top trials recorded in an experiments SQLite DB and check that their metrics reproduce.

Trials are read from the ``trials`` table (highest ``ann_return`` first). Each trial is re-backtested
with ``run_backtest`` using the stored TrendParams columns. The recomputed performance statistics are
then compared against the stored ones within an absolute tolerance.
"""

from __future__ import annotations

import argparse
import json
import math
import sqlite3
from contextlib import closing
from dataclasses import fields
from pathlib import Path
from typing import Any

import pandas as pd

from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest

# Performance metrics that are both stored per trial and recomputed by perf_stats().
_METRIC_KEYS = ("ann_return", "ann_vol", "max_drawdown", "sharpe")


def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]:
    """Load the asset universe, portfolio constraints and fallback tickers from a JSON config.

    Keys read: ``assets`` (required; a list of ``UniverseAsset`` keyword arguments) and an
    optional ``constraints`` object with ``max_positions`` (default 3), ``must_include``
    (per-class minimums for ``commodity``/``rates``/``equity``, default 0), ``risk_proxy``
    and ``rates_fallback``.

    Returns:
        ``(universe, constraints, risk_proxy, rates_fallback)``. ``risk_proxy`` defaults to the
        first asset's ``ts_code`` (or "510300.SH" when the universe is empty), and
        ``rates_fallback`` defaults to "511010.SH".
    """
    conf = json.loads(config_path.read_text(encoding="utf-8"))
    universe = [UniverseAsset(**a) for a in conf["assets"]]
    # `or {}` also tolerates explicit JSON nulls, which `.get(key, {})` would pass through.
    cons = conf.get("constraints") or {}
    must = cons.get("must_include") or {}
    constraints = Constraints(
        max_positions=int(cons.get("max_positions", 3)),
        must_commodity=int(must.get("commodity", 0)),
        must_rates=int(must.get("rates", 0)),
        must_equity=int(must.get("equity", 0)),
    )
    risk_proxy = cons.get("risk_proxy") or (universe[0].ts_code if universe else "510300.SH")
    rates_fallback = cons.get("rates_fallback", "511010.SH")
    return universe, constraints, str(risk_proxy), str(rates_fallback)


def load_prices(
    raw_dir: Path, universe: list[UniverseAsset], start: str, end: str
) -> dict[str, pd.DataFrame]:
    """Read each asset's ``<ts_code without the dot>.parquet`` from ``raw_dir``, limited to [start, end].

    ``trade_date`` is cast to str and compared lexicographically with ``start``/``end``, so both
    sides must use the same zero-padded format. NOTE(review): the defaults used by main() are
    YYYYMMDD; confirm that the raw files use that format too.

    Returns:
        Mapping of ts_code -> filtered price frame.
    """
    out: dict[str, pd.DataFrame] = {}
    for asset in universe:
        path = raw_dir / (asset.ts_code.replace(".", "") + ".parquet")
        df = pd.read_parquet(path)  # freshly read frame, so mutating it in place is safe
        df["trade_date"] = df["trade_date"].astype(str)
        out[asset.ts_code] = df[(df["trade_date"] >= start) & (df["trade_date"] <= end)]
    return out


def perf_stats(equity: pd.Series) -> dict[str, float]:
    """Compute annualized performance statistics of an equity curve, assuming 252 periods per year.

    Returns:
        ``ann_return``, ``ann_vol``, ``max_drawdown`` (a non-positive fraction) and ``sharpe``
        (ann_return / ann_vol, no risk-free rate; NaN when volatility is not positive). Returns
        an empty dict when no period returns can be computed, e.g. for fewer than two points.
    """
    r = equity.pct_change().dropna()
    if r.empty:
        return {}
    ann_ret = float((equity.iloc[-1] / equity.iloc[0]) ** (252 / len(r)) - 1)
    ann_vol = float(r.std(ddof=1) * (252**0.5))
    dd = float((equity / equity.cummax() - 1.0).min())
    sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else float("nan")
    return {"ann_return": ann_ret, "ann_vol": ann_vol, "max_drawdown": dd, "sharpe": sharpe}


def table_columns(con: sqlite3.Connection, table: str) -> list[str]:
    """Return the column names of ``table`` in declaration order, using PRAGMA table_info."""
    # PRAGMA arguments cannot be bound as parameters, so quote the identifier instead.
    quoted = '"' + table.replace('"', '""') + '"'
    return [row[1] for row in con.execute(f"PRAGMA table_info({quoted})")]


def fetch_topn(
    db_path: Path, run_id: str | None, topn: int
) -> tuple[list[str], list[dict[str, Any]]]:
    """Fetch the ``topn`` trials with the highest ``ann_return``, optionally for one ``run_id``.

    Returns:
        ``(columns, rows)``, where each row is a dict keyed by column name.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist. Without this check, sqlite3.connect
            would silently create an empty database file.
    """
    if not db_path.exists():
        raise FileNotFoundError(f"experiments database not found: {db_path}")
    where = "WHERE run_id = ?" if run_id else ""
    params: list[Any] = [run_id] if run_id else []
    params.append(int(topn))
    sql = f"SELECT * FROM trials {where} ORDER BY ann_return DESC LIMIT ?"
    # The sqlite3 connection context manager only commits or rolls back and never closes the
    # connection, so closing() is needed to release it.
    with closing(sqlite3.connect(str(db_path))) as con:
        cur = con.execute(sql, params)
        cols = [desc[0] for desc in cur.description]
        rows = [dict(zip(cols, r)) for r in cur]
    return cols, rows


def _coerce_param(value: Any, default: Any) -> Any:
    """Coerce a value read from sqlite to the type of the matching TrendParams default.

    The sqlite column may hand back REAL values for integer parameters, so int and bool values
    are rounded from float. str defaults keep the value as text. Every other type, including a
    None default, is coerced to float.
    """
    if value is None:
        return None
    kind = type(default)
    if kind is int:
        return int(round(float(value)))
    if kind is bool:
        return bool(int(round(float(value))))
    if kind is str:
        return str(value)
    return float(value)


def _trend_params_from_row(
    row: dict[str, Any], field_defaults: dict[str, Any], default_max_positions: int
) -> TrendParams:
    """Build TrendParams from the trial columns that name a TrendParams field; NULLs are skipped."""
    params = {
        k: _coerce_param(v, field_defaults[k])
        for k, v in row.items()
        if k in field_defaults and v is not None
    }
    # Trials that did not record max_positions fall back to the config's constraint.
    if "max_positions" in field_defaults:
        params.setdefault("max_positions", default_max_positions)
    return TrendParams(**params)


def _as_float(x: Any) -> float:
    """Convert ``x`` to float, mapping None (SQL NULL) to NaN."""
    return math.nan if x is None else float(x)


def _metric_diff(recomputed: Any, stored: Any) -> float:
    """Return ``recomputed - stored``, with None treated as NaN.

    A metric missing on both sides (None/NaN, e.g. an undefined Sharpe that may be stored as
    NULL) counts as equal and yields 0.0. A metric missing on only one side yields NaN so the
    caller can flag it.
    """
    a, b = _as_float(recomputed), _as_float(stored)
    if math.isnan(a) and math.isnan(b):
        return 0.0
    return a - b


def main() -> None:
    """CLI entry point: re-run the top-N trials and report any metric mismatches."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--db", default="data/experiments.sqlite")
    ap.add_argument("--run_id", default=None)
    ap.add_argument("--topn", type=int, default=10)
    ap.add_argument("--config", default="configs/etf_universe_industry_profiled.json")
    ap.add_argument("--rawdir", default="data/raw")
    ap.add_argument("--start", default=None)
    ap.add_argument("--end", default=None)
    ap.add_argument("--tol", type=float, default=1e-6)
    args = ap.parse_args()

    _cols, rows = fetch_topn(Path(args.db), args.run_id, args.topn)
    if not rows:
        print("no trials found")
        return

    universe, constraints, risk_proxy, rates_fallback = load_universe(Path(args.config))
    # Default values determine the target type for each stored parameter. This requires
    # TrendParams to be constructible with no arguments.
    defaults = TrendParams()
    field_defaults = {f.name: getattr(defaults, f.name) for f in fields(TrendParams)}
    raw_dir = Path(args.rawdir)
    tol = float(args.tol)

    # Trials often share a date window, so read the parquet files once per distinct window.
    price_cache: dict[tuple[str, str], dict[str, pd.DataFrame]] = {}
    mismatches = 0
    for idx, row in enumerate(rows, start=1):
        start = str(args.start or row.get("start") or "20200101")
        end = str(args.end or row.get("end") or "20251231")
        window = (start, end)
        if window not in price_cache:
            price_cache[window] = load_prices(raw_dir, universe, start, end)
        # Give each backtest its own copies in case run_backtest mutates its inputs.
        prices = {code: df.copy() for code, df in price_cache[window].items()}

        tp = _trend_params_from_row(row, field_defaults, constraints.max_positions)
        equity, _weights, _trades = run_backtest(
            prices,
            universe,
            constraints,
            tp,
            rates_fallback=rates_fallback,
            risk_proxy=risk_proxy,
        )
        st = perf_stats(equity["equity"])
        diffs = {k: _metric_diff(st.get(k), row.get(k)) for k in _METRIC_KEYS}
        # NaN means the metric is missing on exactly one side, which counts as a mismatch.
        bad = any(math.isnan(d) or abs(d) > tol for d in diffs.values())
        if bad:
            mismatches += 1
        tag = "MISMATCH" if bad else "OK"
        print(f"[{idx}] {tag} id={row.get('id')} run_id={row.get('run_id')} start={start} end={end}")
        print(" orig:", {k: row.get(k) for k in _METRIC_KEYS})
        print(" re :", st)
        print(" diff:", diffs)
    print(f"done. mismatches={mismatches}/{len(rows)}")


if __name__ == "__main__":
    main()