opto.features.priority_search.search_template
Samples
A container for samples collected during the search. It holds a list of RolloutsGraph objects together with the dataset (inputs and infos) that produced them.
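The minimal Python sketch below makes the shape of this container concrete. `RolloutsGraphStub` and `SamplesSketch` are hypothetical names, not opto classes. The field names and the one-graph-per-agent layout are also assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class RolloutsGraphStub:
    """Hypothetical stand-in for RolloutsGraph: the rollouts of one agent on a batch."""
    agent_name: str
    rollouts: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class SamplesSketch:
    """Hypothetical container pairing rollout graphs with the data that produced them."""
    graphs: List[RolloutsGraphStub]
    dataset: Dict[str, List[Any]]  # assumed keys: "inputs" and "infos"

    def __len__(self) -> int:
        # Number of rollout graphs (here, one per evaluated agent).
        return len(self.graphs)


batch = {"inputs": ["2+2", "3*3"], "infos": ["4", "9"]}
graphs = [
    RolloutsGraphStub(name, [{"input": x, "info": i} for x, i in zip(batch["inputs"], batch["infos"])])
    for name in ("agent_a", "agent_b")
]
samples = SamplesSketch(graphs=graphs, dataset=batch)
print(len(samples), len(samples.graphs[0].rollouts))  # 2 2
```

Keeping the dataset next to the graphs means later steps can tell which inputs and infos each rollout came from without re-sampling.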
SearchTemplate
Bases: Minibatch
This class implements a generic template for search algorithms.
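A search template of this kind typically fixes the outer loop and leaves the individual steps to subclasses. The sketch below assumes that pattern. A hypothetical `SearchTemplateSketch` base class calls `sample` and then `update` on each iteration, and a toy `HillClimb` subclass supplies both steps. The sketch does not subclass opto's `Minibatch`, and it is not the library's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class SearchTemplateSketch(ABC):
    """Hypothetical skeleton: the base class fixes the loop, subclasses fill in the steps."""

    def __init__(self, initial_params: Dict[str, float]):
        self.params = dict(initial_params)
        self.proposals: List[Dict[str, float]] = [dict(initial_params)]

    @abstractmethod
    def sample(self, agents: List[Dict[str, float]]) -> List[Tuple[Dict[str, float], float]]:
        """Evaluate each proposal and return (proposal, score) pairs."""

    @abstractmethod
    def update(self, samples) -> Tuple[Dict[str, float], List[Dict[str, float]], Dict[str, Any]]:
        """Return (update_dict, proposals, info_log)."""

    def iterate(self, num_iterations: int) -> Dict[str, float]:
        # Fixed control flow: sample the current proposals, then update from the samples.
        for _ in range(num_iterations):
            samples = self.sample(self.proposals)
            update_dict, self.proposals, _info = self.update(samples)
            self.params.update(update_dict)
        return self.params


class HillClimb(SearchTemplateSketch):
    """Toy subclass: maximize -(x - 3)^2 by proposing x - 1, x, x + 1 around the best."""

    def sample(self, agents):
        return [(a, -(a["x"] - 3.0) ** 2) for a in agents]

    def update(self, samples):
        best, best_score = max(samples, key=lambda s: s[1])
        proposals = [{"x": best["x"] + d} for d in (-1.0, 0.0, 1.0)]
        return dict(best), proposals, {"best_score": best_score}


print(HillClimb({"x": 0.0}).iterate(5))  # {'x': 3.0}
```

The design choice here is the template-method pattern. Swapping the search strategy only requires a new `sample`/`update` pair, and the loop stays unchanged.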
train
train(
guide,
train_dataset,
*,
validate_dataset=None,
validate_guide=None,
batch_size=1,
sub_batch_size=None,
score_range=None,
num_epochs=1,
num_threads=None,
verbose=False,
test_dataset=None,
test_guide=None,
eval_frequency: Union[int, None] = 1,
num_eval_samples: int = 1,
log_frequency=None,
save_frequency: Union[int, None] = None,
save_path: str = "checkpoints/agent.pkl",
**kwargs
)
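The loop below sketches one possible way that the scheduling arguments of `train` interact. It rests on assumptions that this reference does not confirm: each minibatch of `batch_size` examples counts as one iteration, and `eval_frequency` and `save_frequency` mean "every N iterations", with `None` disabling the action. `training_schedule` is a hypothetical helper. It only lists actions and does not call opto. The real method may handle details differently, such as evaluating before the first update.

```python
from typing import List, Optional, Tuple


def training_schedule(
    num_examples: int,
    batch_size: int = 1,
    num_epochs: int = 1,
    eval_frequency: Optional[int] = 1,
    save_frequency: Optional[int] = None,
) -> List[Tuple[int, int, List[str]]]:
    """Hypothetical schedule: (epoch, iteration, actions) for each minibatch."""
    events = []
    iteration = 0
    for epoch in range(num_epochs):
        # One iteration per minibatch; the last batch of an epoch may be smaller.
        for _ in range(0, num_examples, batch_size):
            iteration += 1
            actions = ["sample", "update"]
            if eval_frequency is not None and iteration % eval_frequency == 0:
                actions.append("evaluate")
            if save_frequency is not None and iteration % save_frequency == 0:
                actions.append("save")
            events.append((epoch, iteration, actions))
    return events


for epoch, it, actions in training_schedule(5, batch_size=2, num_epochs=2, eval_frequency=2, save_frequency=3):
    print(epoch, it, actions)
```

With 5 examples and `batch_size=2`, each epoch has 3 iterations. Under these assumptions, evaluation fires at iterations 2, 4, and 6, and saving fires at iterations 3 and 6.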
sample
Sample a batch of data based on the proposed parameters. All proposals are evaluated on the same batch of inputs.
Args:
    agents (list): A list of trace.Module objects (proposed parameters) to evaluate.
    **kwargs: Additional keyword arguments that may be used by the implementation.
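The key property of this step is that every proposal is scored on the same inputs. As a result, score differences reflect the parameters rather than the luck of the batch. The following minimal sketch assumes that agents are plain callables and that the guide is a scoring function. Both are hypothetical simplifications of `trace.Module` and opto's guide objects, and `sample_proposals` is not an opto function.

```python
import random
from typing import Callable, List, Sequence, Tuple


def sample_proposals(
    agents: Sequence[Callable[[str], str]],
    inputs: Sequence[str],
    infos: Sequence[str],
    guide: Callable[[str, str], float],
    batch_size: int,
    rng: random.Random,
) -> Tuple[List[int], List[List[float]]]:
    """Hypothetical sketch: draw ONE batch of indices, then score every agent on it."""
    indices = rng.sample(range(len(inputs)), batch_size)  # shared by all agents
    scores = [[guide(agent(inputs[i]), infos[i]) for i in indices] for agent in agents]
    return indices, scores


def exact_match(output: str, info: str) -> float:
    # Toy guide: full credit only when the output matches the reference info.
    return 1.0 if output == info else 0.0


agents = [str.upper, lambda s: s]
inputs = ["a", "b", "c", "d"]
infos = ["A", "B", "C", "D"]
indices, scores = sample_proposals(agents, inputs, infos, exact_match, batch_size=2, rng=random.Random(0))
print([sum(s) / len(s) for s in scores])  # [1.0, 0.0]
```

Drawing the indices once, outside the per-agent loop, is what guarantees a like-for-like comparison between proposals.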
update
Update the agent based on the provided samples.
Args:
    samples (list): A list of samples from the previous iteration. If None, the agent's parameters are returned without updating.
    verbose (bool, optional): Whether to print verbose output. Defaults to False.
    **kwargs: Additional keyword arguments that may be used by the implementation.
Returns:
    update_dict (dict of Parameter: Any): A dictionary containing the updated parameters of the agent.
    proposals (list of trace.Module): A list of proposed parameters (trace.Module) after the update.
    info_log (dict of str: Any): A dictionary containing logging information about the update process.
This method updates the agent's parameters using samples from the training dataset and the validation dataset (provided by self.get_validate_dataset). It also returns new agents (proposals) that can be used to collect data in the next iteration.
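A simple ranking rule can illustrate the three return values. This sketch assumes that each sample is a `(parameters, scores)` pair. It ranks candidates by mean score with `heapq.nlargest`, returns the best candidate as `update_dict`, and returns the top `num_proposals` candidates as the next proposals. `update_sketch` and its ranking rule are hypothetical. The priority used by opto's priority search may differ, and this sketch ignores the validation dataset.

```python
import heapq
from statistics import mean
from typing import Any, Dict, List, Optional, Tuple

Sample = Tuple[Dict[str, float], List[float]]  # hypothetical: (parameters, per-input scores)


def update_sketch(
    samples: Optional[List[Sample]],
    current: Dict[str, float],
    num_proposals: int = 2,
    verbose: bool = False,
) -> Tuple[Dict[str, float], List[Dict[str, float]], Dict[str, Any]]:
    """Hypothetical update returning (update_dict, proposals, info_log)."""
    if not samples:
        # No samples (None or empty): return the current parameters unchanged.
        return dict(current), [dict(current)], {}
    # Rank candidates by mean score; nlargest keeps the original order on ties.
    ranked = heapq.nlargest(num_proposals, samples, key=lambda s: mean(s[1]))
    best_params, best_scores = ranked[0]
    update_dict = dict(best_params)
    proposals = [dict(params) for params, _ in ranked]
    info_log = {"best_score": mean(best_scores), "num_candidates": len(samples)}
    if verbose:
        print(info_log)
    return update_dict, proposals, info_log


samples = [({"x": 1.0}, [1.0, 0.0]), ({"x": 2.0}, [1.0, 1.0]), ({"x": 3.0}, [0.0, 0.0])]
update_dict, proposals, info_log = update_sketch(samples, current={"x": 0.0})
print(update_dict, proposals)  # {'x': 2.0} [{'x': 2.0}, {'x': 1.0}]
```

Returning several proposals, and not only the single best candidate, keeps some diversity in the candidates that the next call to `sample` evaluates.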