opto.trainer.algorithms.UCBsearch
UCBSearchAlgorithm
UCBSearchAlgorithm(
    agent: Module,
    optimizer,
    max_buffer_size: int = 10,
    ucb_exploration_factor: float = 1.0,
    logger=None,
    num_threads: int = None,
    *args,
    **kwargs
)
Bases: MinibatchAlgorithm
UCB Search Algorithm.
Keeps a buffer of candidates with their statistics (score sum, evaluation count). In each iteration:
1. Picks a candidate 'a' from the buffer with the highest UCB score.
2. Updates the optimizer with 'a's parameters.
3. Draws a minibatch from the training set, performs a forward/backward pass, and calls optimizer.step() to get a new candidate 'a''.
4. Evaluates 'a'' on a validation set minibatch.
5. Updates statistics of 'a' (based on the training minibatch).
6. Adds 'a'' (with its validation stats) to the buffer.
7. If the buffer is full, evicts the candidate with the lowest UCB score.
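The selection and eviction steps (1 and 7) can be illustrated with a minimal, standalone sketch. It does not use this library's classes: `Candidate`, `ucb_score`, `select_best`, and `evict_worst` are hypothetical names, and the scoring rule is the textbook UCB1 formula (mean score plus an exploration bonus), which is an assumption here since the description above does not state the exact formula used.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    """Hypothetical buffer entry holding the statistics named in the description."""
    name: str
    score_sum: float = 0.0
    eval_count: int = 0


def ucb_score(c: Candidate, total_evals: int, exploration_factor: float = 1.0) -> float:
    """Textbook UCB1 score (assumed): mean + factor * sqrt(ln(total) / count)."""
    if c.eval_count == 0:
        return math.inf  # unevaluated candidates are always tried first
    mean = c.score_sum / c.eval_count
    bonus = exploration_factor * math.sqrt(math.log(max(total_evals, 1)) / c.eval_count)
    return mean + bonus


def select_best(buffer: List[Candidate], exploration_factor: float = 1.0) -> Candidate:
    """Step 1: pick the candidate with the highest UCB score."""
    total = sum(c.eval_count for c in buffer)
    return max(buffer, key=lambda c: ucb_score(c, total, exploration_factor))


def evict_worst(buffer: List[Candidate], max_buffer_size: int,
                exploration_factor: float = 1.0) -> None:
    """Step 7: drop lowest-UCB candidates until the buffer fits its size limit."""
    while len(buffer) > max_buffer_size:
        total = sum(c.eval_count for c in buffer)
        worst = min(buffer, key=lambda c: ucb_score(c, total, exploration_factor))
        buffer.remove(worst)


# Illustrative statistics only; not results from any real run.
buffer = [
    Candidate("a1", score_sum=8.0, eval_count=10),
    Candidate("a2", score_sum=1.8, eval_count=2),
    Candidate("a3", score_sum=0.5, eval_count=1),
]
chosen = select_best(buffer)
evict_worst(buffer, max_buffer_size=2)
remaining = [c.name for c in buffer]
```

With these numbers, the rarely evaluated candidate `a3` is selected despite its lower mean, because its exploration bonus is largest. The heavily evaluated `a1` is evicted because its bonus has shrunk. Raising the exploration factor strengthens this preference for under-sampled candidates.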
train
train(
    guide,
    train_dataset: Dict[str, List[Any]],
    *,
    validation_dataset: Optional[
        Dict[str, List[Any]]
    ] = None,
    num_search_iterations: int = 100,
    train_batch_size: int = 2,
    evaluation_batch_size: int = 20,
    eval_frequency: int = 1,
    log_frequency: Optional[int] = None,
    save_frequency: Optional[int] = None,
    save_path: str = "checkpoints/ucb_agent.pkl",
    min_score_for_agent_update: Optional[float] = None,
    verbose: Union[bool, str] = False,
    num_threads: Optional[int] = None,
    **kwargs
) -> Tuple[Dict[str, Any], float]
Main training loop for the UCB Search Algorithm.