
opto.trainer.algorithms.UCBsearch

UCBSearchAlgorithm

UCBSearchAlgorithm(
    agent: Module,
    optimizer,
    max_buffer_size: int = 10,
    ucb_exploration_factor: float = 1.0,
    logger=None,
    num_threads: int = None,
    *args,
    **kwargs
)

Bases: MinibatchAlgorithm

UCB Search Algorithm.

Keeps a buffer of candidates with their statistics (score sum, evaluation count). In each iteration:

1. Picks a candidate 'a' from the buffer with the highest UCB score.
2. Updates the optimizer with 'a's parameters.
3. Draws a minibatch from the training set, performs a forward/backward pass, and calls optimizer.step() to get a new candidate 'a''.
4. Evaluates 'a'' on a validation set minibatch.
5. Updates statistics of 'a' (based on the training minibatch).
6. Adds 'a'' (with its validation stats) to the buffer.
7. If the buffer is full, evicts the candidate with the lowest UCB score.
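The selection step can be illustrated with a small sketch. The exact scoring formula is not given here, so this assumes a UCB1-style score (empirical mean plus an exploration bonus scaled by the exploration factor). The dictionary layout of a buffer entry and the helper names `ucb_score` and `pick_best` are hypothetical and are not part of the library.

```python
import math

def ucb_score(score_sum, eval_count, total_evals, exploration_factor=1.0):
    """Assumed UCB1-style score: mean reward plus an exploration bonus."""
    if eval_count == 0:
        # Candidates that were never evaluated get top priority.
        return float("inf")
    mean = score_sum / eval_count
    # max(..., 1) keeps log() defined for degenerate totals.
    bonus = exploration_factor * math.sqrt(math.log(max(total_evals, 1)) / eval_count)
    return mean + bonus

def pick_best(buffer, exploration_factor=1.0):
    """Return the entry with the highest assumed UCB score."""
    total = sum(c["eval_count"] for c in buffer)
    return max(
        buffer,
        key=lambda c: ucb_score(c["score_sum"], c["eval_count"], total, exploration_factor),
    )

# Hypothetical buffer entries: only score sum and evaluation count are tracked.
buffer = [
    {"name": "a", "score_sum": 8.0, "eval_count": 10},  # mean 0.80, well explored
    {"name": "b", "score_sum": 1.5, "eval_count": 2},   # mean 0.75, barely explored
]

chosen = pick_best(buffer)                 # exploration bonus favours "b"
greedy = pick_best(buffer, exploration_factor=0.0)  # pure mean favours "a"
```

Under this assumed formula, a larger exploration factor shifts selection toward candidates with few evaluations, while a factor of 0 reduces selection to picking the best empirical mean.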

buffer instance-attribute

buffer = deque(maxlen=max_buffer_size)
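A note on `deque(maxlen=...)`: when a bounded deque is full, `append` silently discards the item at the opposite end (the oldest one), regardless of any score. Evicting the lowest-scoring candidate therefore requires removing it explicitly before appending. The sketch below shows both behaviours; `add_with_eviction` and the `(name, score)` tuples are hypothetical and only illustrate the idea, not how the library does it.

```python
from collections import deque

# Default behaviour: a full bounded deque drops its oldest item on append.
buf = deque(maxlen=3)
for name in ["a", "b", "c", "d"]:
    buf.append(name)
remaining = list(buf)  # "a" was discarded automatically

def add_with_eviction(buffer, candidate, score_fn):
    """Append candidate, first removing the lowest-scoring entry if full."""
    if buffer.maxlen is not None and len(buffer) == buffer.maxlen:
        worst = min(buffer, key=score_fn)
        buffer.remove(worst)  # removes the first matching entry
    buffer.append(candidate)

# Hypothetical (name, score) entries; the score stands in for a UCB value.
scored = deque([("a", 0.9), ("b", 0.2), ("c", 0.5)], maxlen=3)
add_with_eviction(scored, ("d", 0.7), score_fn=lambda c: c[1])
```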

max_buffer_size instance-attribute

max_buffer_size = max_buffer_size

ucb_exploration_factor instance-attribute

ucb_exploration_factor = ucb_exploration_factor

train

train(
    guide,
    train_dataset: Dict[str, List[Any]],
    *,
    validation_dataset: Optional[
        Dict[str, List[Any]]
    ] = None,
    num_search_iterations: int = 100,
    train_batch_size: int = 2,
    evaluation_batch_size: int = 20,
    eval_frequency: int = 1,
    log_frequency: Optional[int] = None,
    save_frequency: Optional[int] = None,
    save_path: str = "checkpoints/ucb_agent.pkl",
    min_score_for_agent_update: Optional[float] = None,
    verbose: Union[bool, str] = False,
    num_threads: Optional[int] = None,
    **kwargs
) -> Tuple[Dict[str, Any], float]

Main training loop for UCB Search Algorithm.

select

select(buffer)

Can be overridden in a subclass to implement a different selection strategy.
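As an illustration of what an alternative strategy might look like, here is a standalone epsilon-greedy selector. It is a sketch only: the buffer entry layout (`score_sum`, `eval_count`), the function names, and the return value are assumptions, and a real override would need to match whatever `select(buffer)` is expected to return in the library.

```python
import random

def select_greedy(buffer):
    """Pick the entry with the highest empirical mean score."""
    return max(buffer, key=lambda c: c["score_sum"] / max(c["eval_count"], 1))

def select_epsilon_greedy(buffer, epsilon=0.1, rng=random):
    """With probability epsilon pick a random entry, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.choice(list(buffer))
    return select_greedy(buffer)

# Hypothetical buffer entries for demonstration.
candidates = [
    {"name": "a", "score_sum": 8.0, "eval_count": 10},
    {"name": "b", "score_sum": 1.5, "eval_count": 2},
]

always_greedy = select_epsilon_greedy(candidates, epsilon=0.0)
always_random = select_epsilon_greedy(candidates, epsilon=1.0, rng=random.Random(0))
```

A subclass could wrap a function like this in its own `select` method; passing a seeded `random.Random` instance keeps the random branch reproducible in tests.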