opto.features.priority_search.search_template¶
Samples ¶
A container for the samples collected during the search algorithm. It holds a list of RolloutsGraph objects together with the dataset (inputs and infos) that was used to produce them.
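To make the shape of this container concrete, here is a minimal sketch. It is not the library's implementation: `RolloutsGraphStub` and `SamplesSketch` are hypothetical stand-ins, and the `{"inputs": [...], "infos": [...]}` dataset layout is an assumption based on the description above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class RolloutsGraphStub:
    """Hypothetical stand-in for RolloutsGraph: the rollouts of one proposal."""
    rollouts: List[Any] = field(default_factory=list)


@dataclass
class SamplesSketch:
    """Illustrative container pairing rollout graphs with the data that produced them."""
    samples: List[RolloutsGraphStub]
    dataset: Dict[str, List[Any]]  # assumed layout: {"inputs": [...], "infos": [...]}

    def __post_init__(self) -> None:
        # Every input needs a matching info entry.
        if len(self.dataset["inputs"]) != len(self.dataset["infos"]):
            raise ValueError("inputs and infos must have the same length")

    def __len__(self) -> int:
        return len(self.samples)


batch = {"inputs": ["2+2", "3*3"], "infos": ["4", "9"]}
container = SamplesSketch(
    samples=[RolloutsGraphStub(), RolloutsGraphStub()],
    dataset=batch,
)
```

Keeping the dataset next to the rollouts means a later update step can tell which inputs each rollout came from.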
SearchTemplate ¶
Bases: Minibatch
A generic template for search algorithms.
train ¶
train(
    guide,
    train_dataset,
    *,
    validate_dataset=None,
    validate_guide=None,
    batch_size=1,
    sub_batch_size=None,
    score_range=None,
    num_epochs=1,
    num_threads=None,
    verbose=False,
    test_dataset=None,
    test_guide=None,
    eval_frequency: Union[int, None] = 1,
    num_eval_samples: int = 1,
    log_frequency=None,
    save_frequency: Union[int, None] = None,
    save_path: str = "checkpoints/agent.pkl",
    **kwargs
)
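The bare `*` in the signature makes every argument after `train_dataset` keyword-only. The sketch below illustrates that calling convention with `train_stub`, a hypothetical function that mirrors only part of the signature and does no training; the string `"guide"` and the dataset are placeholders.

```python
from typing import Optional


def train_stub(guide, train_dataset, *, batch_size=1, num_epochs=1,
               eval_frequency: Optional[int] = 1,
               save_path: str = "checkpoints/agent.pkl", **kwargs):
    """Hypothetical stand-in: echoes its settings instead of training."""
    return {
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "eval_frequency": eval_frequency,
        "save_path": save_path,
        "extra": kwargs,
    }


dataset = {"inputs": ["a", "b"], "infos": ["A", "B"]}

# Options must be passed by name.
settings = train_stub("guide", dataset, batch_size=2, num_epochs=3)

# Passing an option positionally raises TypeError because of the bare `*`.
try:
    train_stub("guide", dataset, 2)
    positional_ok = True
except TypeError:
    positional_ok = False
```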
sample ¶
Sample a batch of data based on the proposed parameters. All proposals are evaluated on the same batch of inputs.
Args:
    agents (list): A list of trace.Modules (proposed parameters) to evaluate.
    **kwargs: Additional keyword arguments that may be used by the implementation.
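The idea of evaluating all proposals on the same batch can be sketched as follows. This is an illustration under stated assumptions, not the library's `sample` method: `sample_sketch` is a hypothetical function, plain callables stand in for trace.Modules, and the dataset is assumed to be a dict with `inputs` and `infos` lists.

```python
import random
from typing import Any, Callable, Dict, List


def sample_sketch(agents: List[Callable[[Any], Any]],
                  dataset: Dict[str, List[Any]],
                  batch_size: int,
                  rng: random.Random) -> Dict[str, Any]:
    """Draw one minibatch, then run every proposal on that same minibatch."""
    n = len(dataset["inputs"])
    indices = rng.sample(range(n), k=min(batch_size, n))
    inputs = [dataset["inputs"][i] for i in indices]
    infos = [dataset["infos"][i] for i in indices]
    # Each agent sees identical inputs, so their outputs are directly comparable.
    rollouts = [[agent(x) for x in inputs] for agent in agents]
    return {"inputs": inputs, "infos": infos, "rollouts": rollouts}


data = {"inputs": [1, 2, 3, 4], "infos": [2, 4, 6, 8]}
proposals = [lambda x: x * 2, lambda x: x + 1]  # hypothetical stand-ins for trace.Modules
result = sample_sketch(proposals, data, batch_size=2, rng=random.Random(0))
```

Drawing the batch once, outside the loop over agents, is what keeps score differences attributable to the proposals and not to the data each one happened to see.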
update ¶
Update the agent based on the provided samples.
Args:
    samples (list): A list of samples from the previous iteration. If None, the agent's parameters are returned without updating.
    verbose (bool, optional): Whether to print verbose output. Defaults to False.
    **kwargs: Additional keyword arguments that may be used by the implementation.
Returns:
    update_dict (dict of Parameter: Any): A dictionary containing the updated parameters of the agent.
    proposals (list of trace.Module): A list of proposed parameters (trace.Module) after the update.
    info_log (dict of str: Any): A dictionary containing logging information about the update process.
This method updates the agent's parameters using samples from the training dataset and the validation dataset (provided by self.get_validate_dataset). It also returns new agents (proposals) that can be used to collect data for the next iteration.
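The contract described above (a three-part return value, and a no-op when `samples` is None) can be sketched with a toy search step. Everything here is hypothetical: `UpdateSketch` is not a subclass of SearchTemplate, plain dicts with string keys stand in for Parameters and trace.Modules, and the `step` keyword is invented for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


class UpdateSketch:
    """Hypothetical search step: adopts the best-scoring candidate value."""

    def __init__(self, value: float) -> None:
        self.params = {"value": value}  # stand-in for the agent's parameters

    def update(self, samples: Optional[List[Dict[str, float]]] = None,
               verbose: bool = False, **kwargs: Any
               ) -> Tuple[Dict[str, float], List[Dict[str, float]], Dict[str, Any]]:
        if samples is None:
            # No data: return the current parameters unchanged.
            return dict(self.params), [dict(self.params)], {"updated": False}
        best = max(samples, key=lambda s: s["score"])
        self.params = {"value": best["value"]}
        # Propose neighbours of the new value for the next round of sampling.
        step = kwargs.get("step", 1.0)
        proposals = [{"value": best["value"] + d} for d in (-step, step)]
        info_log = {"updated": True, "best_score": best["score"]}
        if verbose:
            print(info_log)
        return dict(self.params), proposals, info_log


algo = UpdateSketch(value=0.0)
unchanged, _, log0 = algo.update(None)
update_dict, proposals, info_log = algo.update(
    [{"value": 1.0, "score": 0.2}, {"value": 3.0, "score": 0.9}]
)
```

The returned proposals would then be passed to the sampling step, closing the loop between `update` and `sample`.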