core/frontier_selection.py
Pure-function weighting layer that turns PRM uncertainty + trajectory rarity into a frontier-aware cluster pick for self-play.
What it solves
The brittle-pool scoring in FrontierTracker.pick_seed sees outcomes but not coverage. A cluster the agent has barely tried looks identical to one it solves on the first try: both look quiet to a signal that only watches recent failures. Frontier weighting separates the two by adding two complementary signals:
- Uncertainty: PRMScorer.uncertainty(state, action), computed as the boundary distance 1 − 2·|p − 0.5|, where p is the PRM's score. It is high when the PRM scores a representative state near the decision boundary. An untrained model returns 1.0 by contract, which deliberately biases selection toward clusters the PRM has no opinion on.
- Rarity: 1 / (1 + log1p(count)) over Trajectory.cluster groupings. Smooth and bounded in (0, 1], with logarithmic decay so even well-explored clusters keep some weight when their uncertainty is high.
The two are multiplied: a cluster needs both "the PRM is unsure" AND "we don't have many examples" to win. A low value in either signal is enough to collapse the product, which is what we want: mastered clusters (low uncertainty) and well-explored clusters (low rarity) are down-weighted automatically.
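The two signals and their product are simple enough to sketch directly. The helpers below (boundary_distance, rarity, frontier_weight) are hypothetical names that illustrate the formulas above; they are not the module's API, and p is assumed to be the PRM's score in [0, 1].

```python
import math


def boundary_distance(p: float) -> float:
    """Uncertainty as 1 - 2*|p - 0.5|: 1.0 at p = 0.5, 0.0 at p = 0 or p = 1."""
    return 1.0 - 2.0 * abs(p - 0.5)


def rarity(count: int) -> float:
    """Log-decay rarity 1 / (1 + log1p(count)); negative counts are treated as 0."""
    return 1.0 / (1.0 + math.log1p(max(count, 0)))


def frontier_weight(p: float, count: int) -> float:
    # Multiplicative: a low value in either factor drags the product down.
    return boundary_distance(p) * rarity(count)


# Unsure PRM on a never-tried cluster: both factors are 1.0.
print(frontier_weight(0.5, 0))   # 1.0
# Confident PRM (p = 0.99) on the same untried cluster: weight collapses to about 0.02.
print(frontier_weight(0.99, 0))
```

Because the combination is a product rather than a sum, no amount of rarity can rescue a cluster the PRM is already sure about.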
Design non-negotiables
- Pure functions only. No I/O, no globals, no logging side effects. The caller (typically FrontierTracker.pick_frontier_seed via Dreamer.synthetic_self_play) does the on-disk reads and wires the result through.
- Fail-safe. Every function returns a defined value for empty, NaN, inf, or negative inputs. compute_cluster_uncertainty catches per-cluster exceptions and assigns maximum uncertainty (1.0), so a bad cluster gets explored rather than silently dropped.
- Deterministic-with-seed sampling. pick_weighted accepts an injected random.Random so tests can pin the picker's behaviour without monkeypatching globals.
- Synthetic representative state, not real trajectory sampling. The PRM is queried with one templated request ("solve a {cluster} challenge") per cluster. We deliberately do NOT sample real trajectories here: that would couple the seed picker to the trajectory store's read latency and reintroduce I/O. Aggregating uncertainty over real trajectories is a future refinement; the templated state is enough for the boundary-distance signal we use.
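The fail-safe and synthetic-state rules above can be sketched together. This is a minimal sketch, not the module's implementation: PlanState and ActionFeatures are reduced to hypothetical one-field dataclasses, the scorer is any object with an uncertainty(state, action) method, and mapping non-finite scores to 1.0 and clamping to [0, 1] are assumptions.

```python
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Protocol, Tuple


# Hypothetical stand-ins: the real PlanState / ActionFeatures carry more fields.
@dataclass(frozen=True)
class PlanState:
    request: str


@dataclass(frozen=True)
class ActionFeatures:
    cluster: str


class UncertaintyScorer(Protocol):
    def uncertainty(self, state: PlanState, action: ActionFeatures) -> float: ...


def representative_state(cluster_key: str) -> Tuple[PlanState, ActionFeatures]:
    # Templated request only; no trajectory-store reads, so no I/O.
    state = PlanState(request=f"solve a {cluster_key} challenge")
    return state, ActionFeatures(cluster=cluster_key)


def compute_cluster_uncertainty(
    prm_scorer: Optional[UncertaintyScorer], cluster_keys: Iterable[str]
) -> Dict[str, float]:
    keys = list(cluster_keys)
    if prm_scorer is None:
        return {key: 1.0 for key in keys}  # no scorer: maximal uncertainty everywhere
    result: Dict[str, float] = {}
    for key in keys:
        try:
            value = float(prm_scorer.uncertainty(*representative_state(key)))
        except Exception:
            value = 1.0  # a failing cluster gets explored, not silently dropped
        if not math.isfinite(value):
            value = 1.0  # assumption: NaN/inf scores also map to max uncertainty
        result[key] = min(max(value, 0.0), 1.0)  # assumption: clamp into [0, 1]
    return result


class FlakyScorer:
    """Toy scorer: returns 0.2 for every cluster except bash, where it raises."""

    def uncertainty(self, state: PlanState, action: ActionFeatures) -> float:
        if action.cluster == "bash":
            raise RuntimeError("feature extraction failed")
        return 0.2


print(compute_cluster_uncertainty(FlakyScorer(), ["sql", "bash"]))
# {'sql': 0.2, 'bash': 1.0}
```

Catching per cluster, rather than around the whole loop, means one broken cluster cannot zero out the signal for all the others.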
Public API
| Function | Purpose |
|---|---|
| representative_state(cluster_key) → (PlanState, ActionFeatures) | Synthesise a stand-in (state, action) pair for "a fresh attempt at this cluster". |
| compute_cluster_uncertainty(prm_scorer, cluster_keys) → {cluster: float} | PRM boundary distance per cluster. None scorer → 1.0 everywhere. Per-cluster exceptions → 1.0 (max uncertainty, biases toward exploration). |
| compute_cluster_rarity(trajectory_counts, cluster_keys) → {cluster: float} | Log-decay rarity in (0, 1]. Missing or negative counts are treated as 0 (max rarity). |
| combine_weights(uncertainty, rarity, exclude=None) → {cluster: float} | Multiplicative combiner. Excluded clusters (typically saturated) get weight 0 but stay in the dict so callers can log what was filtered. Any NaN, inf, or negative value that slips through becomes 0. |
| pick_weighted(weights, *, rng=None) → Optional[str] | Sample a single cluster_key with probability proportional to its weight. Returns None when every weight is zero (the caller falls back to cold-start). |
| count_trajectories_by_cluster(trajectories) → {cluster: int} | Count an iterable of Trajectory objects by their .cluster attribute. Trajectories with a None, empty, or missing cluster are skipped. |
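The dict-level functions in the table can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the module's code: trajectories are any objects with an optional .cluster attribute, combine_weights is assumed to give weight 0 to a cluster missing from either input, and random.Random.choices stands in for whatever sampler the module actually uses.

```python
import math
import random
from collections import Counter
from types import SimpleNamespace
from typing import Dict, Iterable, Mapping, Optional, Sequence


def count_trajectories_by_cluster(trajectories: Iterable[object]) -> Dict[str, int]:
    # getattr(..., None) skips missing, None, and empty-string clusters alike.
    return dict(Counter(t.cluster for t in trajectories if getattr(t, "cluster", None)))


def compute_cluster_rarity(
    trajectory_counts: Mapping[str, int], cluster_keys: Sequence[str]
) -> Dict[str, float]:
    rarity: Dict[str, float] = {}
    for key in cluster_keys:
        n = trajectory_counts.get(key, 0)  # missing key -> 0
        if not math.isfinite(n) or n < 0:
            n = 0  # negative or non-finite counts also mean "never tried"
        rarity[key] = 1.0 / (1.0 + math.log1p(n))
    return rarity


def combine_weights(
    uncertainty: Mapping[str, float],
    rarity: Mapping[str, float],
    exclude: Optional[Iterable[str]] = None,
) -> Dict[str, float]:
    excluded = set(exclude or ())
    weights: Dict[str, float] = {}
    # Sorted keys keep dict order (and hence seeded picks) stable across runs.
    for key in sorted(set(uncertainty) | set(rarity)):
        w = uncertainty.get(key, 0.0) * rarity.get(key, 0.0)
        if key in excluded or not math.isfinite(w) or w < 0:
            w = 0.0  # filtered clusters stay in the dict with weight 0
        weights[key] = w
    return weights


def pick_weighted(
    weights: Mapping[str, float], *, rng: Optional[random.Random] = None
) -> Optional[str]:
    rng = rng if rng is not None else random.Random()
    live = [(k, w) for k, w in weights.items() if math.isfinite(w) and w > 0]
    if not live:
        return None  # every weight is zero: caller falls back to cold-start
    keys, ws = zip(*live)
    return rng.choices(keys, weights=ws, k=1)[0]


trajs = [SimpleNamespace(cluster="sql"), SimpleNamespace(cluster="bash"),
         SimpleNamespace(cluster="bash"), SimpleNamespace(cluster=None), SimpleNamespace()]
counts = count_trajectories_by_cluster(trajs)  # {'sql': 1, 'bash': 2}
rar = compute_cluster_rarity(counts, ["algo", "bash", "sql"])
weights = combine_weights({"algo": 0.05, "bash": 0.9, "sql": 0.9}, rar, exclude={"algo"})
print(weights["algo"])  # 0.0: excluded, but still present for logging
print(pick_weighted(weights, rng=random.Random(0)))
```

Iterating in sorted order matters for the deterministic-with-seed guarantee: string hashing is randomised per interpreter run, so iterating a raw set union could change the sampling order, and therefore the pick, even with the same seed.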
Composition (as wired in dream.py)
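```python
# Excerpt from dream.py: ctx, args, frontier_tracker and TEMPLATES come from the
# surrounding scope; the import path assumes the layout core/frontier_selection.py.
from core.frontier_selection import (
    compute_cluster_rarity,
    compute_cluster_uncertainty,
    count_trajectories_by_cluster,
)

# Candidate set: every templated cluster plus any cluster the tracker has seen.
candidate_clusters = sorted(set(TEMPLATES.keys()) | frontier_tracker.clusters)
counts = count_trajectories_by_cluster(
    ctx.trajectory_collector.iter_trajectories()
)
unc = compute_cluster_uncertainty(ctx.prm_scorer, candidate_clusters)
rar = compute_cluster_rarity(counts, candidate_clusters)
seed = frontier_tracker.pick_frontier_seed(
    uncertainty_by_cluster=unc,
    rarity_by_cluster=rar,
    uniform_sample_prob=args.frontier_uniform_sample_prob,
)
# seed → same dict shape as pick_seed; mode="frontier_weighted" on the new path.
```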
Weight math (worked example)
With three clusters where SQL is rare and uncertain, BASH is well-explored but uncertain, and ALGO has only a handful of trajectories but a confident PRM:

| cluster | uncertainty | trajectory_count | rarity | combined |
|---|---|---|---|---|
| sql | 0.90 | 1 | 0.591 | 0.532 |
| bash | 0.90 | 500 | 0.139 | 0.125 |
| algo | 0.05 | 5 | 0.358 | 0.018 |

SQL wins ~79% of picks, BASH ~18%, ALGO ~3%. The well-explored-but-uncertain BASH still gets meaningful airtime, while the confident ALGO is effectively skipped without being hard-blocked: its rarity is higher than BASH's, but its low uncertainty collapses the product.
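The worked example can be recomputed from the stated formulas. This short check is a verification sketch rather than part of the module: it computes rarity as 1 / (1 + log1p(count)), multiplies by the given uncertainty, and normalises the products into pick shares.

```python
import math

# Worked-example inputs: cluster -> (uncertainty, trajectory_count).
example = {"sql": (0.90, 1), "bash": (0.90, 500), "algo": (0.05, 5)}

rarity = {name: 1.0 / (1.0 + math.log1p(n)) for name, (_, n) in example.items()}
combined = {name: unc * rarity[name] for name, (unc, _) in example.items()}
total = sum(combined.values())
share = {name: w / total for name, w in combined.items()}

for name in example:
    print(f"{name:<5} rarity={rarity[name]:.3f} "
          f"combined={combined[name]:.3f} share={share[name]:.0%}")
```

It prints rarities of 0.591, 0.139 and 0.358, combined weights of 0.532, 0.125 and 0.018, and shares of 79%, 18% and 3% for sql, bash and algo respectively.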
Sanity floor
The picker in FrontierTracker.pick_frontier_seed wraps this layer with a uniform-sample sanity roll (default 20%, exposed as --frontier-uniform-sample-prob). When the roll hits, the legacy pick_seed is called regardless of weights. The PRM is itself learned from trajectories that the self-play loop produces, so without this floor an early, cold-start bias could reinforce itself onto a single cluster indefinitely.
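A minimal sketch of the sanity floor, assuming a hypothetical legacy_pick callable that returns the same dict shape as pick_seed. The {"cluster", "mode"} keys and the "uniform_sample" tag are illustrative, not the real contract, and routing the all-zero-weight case to the same legacy picker is an assumption.

```python
import math
import random
from typing import Callable, Dict, Mapping, Optional

Seed = Dict[str, object]  # hypothetical seed-dict shape


def pick_with_floor(
    weights: Mapping[str, float],
    legacy_pick: Callable[[], Seed],
    *,
    uniform_sample_prob: float = 0.2,
    rng: Optional[random.Random] = None,
) -> Seed:
    rng = rng if rng is not None else random.Random()
    # Sanity roll: with probability uniform_sample_prob, ignore the weights.
    if rng.random() < uniform_sample_prob:
        return {**legacy_pick(), "mode": "uniform_sample"}
    live = {k: w for k, w in weights.items() if math.isfinite(w) and w > 0}
    if not live:
        return legacy_pick()  # assumption: all-zero weights reuse the legacy picker
    cluster = rng.choices(list(live), weights=list(live.values()), k=1)[0]
    return {"cluster": cluster, "mode": "frontier_weighted"}


def legacy_pick_seed() -> Seed:
    return {"cluster": "sql", "mode": "legacy"}  # stand-in for pick_seed


rng = random.Random(7)
modes = [pick_with_floor({"bash": 1.0}, legacy_pick_seed, rng=rng)["mode"] for _ in range(1000)]
print(modes.count("uniform_sample") / len(modes))  # close to 0.2
```

Because the roll happens before the weights are consulted, even a PRM that has collapsed onto one cluster still leaves roughly uniform_sample_prob of picks to the legacy scorer.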
Tests
- tests/test_frontier_selection.py: 32 cases over the six public functions (pure math: uncertainty boundary cases, rarity decay monotonicity, weight combination, weighted sampling distribution, deterministic-with-seed behaviour, end-to-end composition).
- tests/test_frontier_pick_frontier_seed.py: 9 cases over the integration with FrontierTracker (empty-signal fallback, uniform-sample tagging, all-zero-weight fallback, saturated-cluster exclusion, dict-shape contract).
- tests/test_prm_uncertainty.py: 10 cases pinning the PRMScorer.uncertainty contract this module depends on.
- tests/test_dream_frontier_weighted.py: 4 dream-integration cases (real PRM and real TrajectoryCollector, mocked LLM/sandbox) covering: the new path engages with real wiring; the --no-frontier-selfplay kill switch; an untrained PRM falls back; a missing collector falls back.