
opto.features.priority_search.search_template

Samples

Samples(
    samples: List[RolloutsGraph],
    dataset: Dict[str, List[Any]],
)

A container for samples collected during the search algorithm. It holds a list of RolloutsGraph objects together with the dataset (inputs and infos) from which those RolloutsGraph objects were created.

samples instance-attribute

samples: List[RolloutsGraph] = samples

dataset instance-attribute

dataset: Dict[str, List[Any]] = dataset

n_sub_batches property

n_sub_batches: int

Number of sub-batches in the samples.

add_samples

add_samples(samples)

Add samples to the Samples object.

get_batch

get_batch()
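To make the container's role concrete, here is a minimal stand-in written with plain Python types. It is a sketch under stated assumptions, not the library's implementation: `ToySamples` is hypothetical, RolloutsGraph is replaced by plain lists, the dataset keys `"inputs"` and `"infos"` are assumed, `n_sub_batches` is assumed to count one sub-batch per entry in `samples`, and `add_samples` is assumed to merge another container by appending.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToySamples:
    """Hypothetical stand-in for Samples; RolloutsGraph objects are replaced by plain lists."""

    samples: List[Any] = field(default_factory=list)
    dataset: Dict[str, List[Any]] = field(
        default_factory=lambda: {"inputs": [], "infos": []}
    )

    @property
    def n_sub_batches(self) -> int:
        # Assumption: each entry in `samples` holds the rollouts of one sub-batch.
        return len(self.samples)

    def add_samples(self, other: "ToySamples") -> None:
        # Assumption: merging appends the other container's graphs and dataset entries.
        self.samples.extend(other.samples)
        for key, values in other.dataset.items():
            self.dataset.setdefault(key, []).extend(values)


a = ToySamples(samples=[["rollout-1"]], dataset={"inputs": ["x1"], "infos": ["y1"]})
b = ToySamples(samples=[["rollout-2"]], dataset={"inputs": ["x2"], "infos": ["y2"]})
a.add_samples(b)
print(a.n_sub_batches)      # 2
print(a.dataset["inputs"])  # ['x1', 'x2']
```

Keeping the dataset next to the rollouts means each group of rollouts can be traced back to the inputs and infos that produced it.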

SearchTemplate

SearchTemplate(
    agent,
    optimizer,
    num_threads: int = None,
    logger=None,
    *args,
    **kwargs
)

Bases: Minibatch

Implements a generic template for search algorithms.
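The template idea can be illustrated with a simplified propose-evaluate loop in which subclasses fill in the sampling and update steps. This is a hypothetical sketch, not SearchTemplate's code: `ToySearchTemplate` and `HillClimb` are invented names, parameters are plain dicts rather than trace.Module objects, and the method signatures are simplified relative to the documented `sample` and `update`. It only mirrors the documented shape in which `update` returns `(update_dict, proposals, info_log)` and the proposals feed the next round of sampling.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class ToySearchTemplate(ABC):
    """Hypothetical skeleton of a propose-evaluate search loop (not the library's code)."""

    def __init__(self, params: Dict[str, Any]):
        self.params = params
        self.history: List[Dict[str, Any]] = []

    @abstractmethod
    def sample(self, agents: List[Dict[str, Any]]) -> List[float]:
        """Evaluate every candidate on the same data; return one score per candidate."""

    @abstractmethod
    def update(
        self, samples: List[float], proposals: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]], Dict[str, Any]]:
        """Return (update_dict, new proposals, info_log)."""

    def train(self, num_steps: int) -> Dict[str, Any]:
        proposals = [dict(self.params)]
        for step in range(num_steps):
            samples = self.sample(proposals)  # collect data for all candidates
            update_dict, proposals, info_log = self.update(samples, proposals)
            self.params.update(update_dict)  # apply the chosen update
            self.history.append({"step": step, **info_log})  # record what happened
        return self.params


class HillClimb(ToySearchTemplate):
    """Toy subclass: maximize -(x - 3)^2 by proposing x - 1, x, x + 1."""

    def sample(self, agents):
        return [-(a["x"] - 3) ** 2 for a in agents]

    def update(self, samples, proposals):
        best = proposals[max(range(len(samples)), key=samples.__getitem__)]
        x = best["x"]
        new_proposals = [{"x": x - 1}, {"x": x}, {"x": x + 1}]
        return {"x": x}, new_proposals, {"best_score": max(samples)}


algo = HillClimb({"x": 0})
print(algo.train(5))  # {'x': 3}
```

The base class owns the loop and the bookkeeping, while each concrete search strategy only decides how to score candidates and how to propose new ones.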

max_score property

max_score

Maximum score that can be achieved by the agent.

min_score property

min_score

Minimum score that can be achieved by the agent.

train

train(
    guide,
    train_dataset,
    *,
    validate_dataset=None,
    validate_guide=None,
    batch_size=1,
    sub_batch_size=None,
    score_range=None,
    num_epochs=1,
    num_threads=None,
    verbose=False,
    test_dataset=None,
    test_guide=None,
    eval_frequency: Union[int, None] = 1,
    num_eval_samples: int = 1,
    log_frequency=None,
    save_frequency: Union[int, None] = None,
    save_path: str = "checkpoints/agent.pkl",
    **kwargs
)

sample

sample(agents, verbose=False, **kwargs)

Sample a batch of data based on the proposed parameters. All proposals are evaluated on the same batch of inputs.

Args:
    agents (list): A list of trace.Module objects (proposed parameters) to evaluate.
    verbose (bool, optional): Whether to print verbose output. Defaults to False.
    **kwargs: Additional keyword arguments that may be used by the implementation.
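Evaluating every proposal on the same batch of inputs makes their scores directly comparable, since differences cannot come from one candidate happening to draw easier examples. The sketch below shows that idea with plain callables as agents. It is an illustration under assumptions, not the method's implementation: `sample_same_batch` is a hypothetical helper, and the exact-match scoring stands in for a real guide.

```python
import random
from typing import Callable, List, Sequence, Tuple


def sample_same_batch(
    agents: Sequence[Callable[[float], float]],
    inputs: Sequence[float],
    targets: Sequence[float],
    batch_size: int,
    rng: random.Random,
) -> Tuple[List[int], List[List[float]]]:
    """Draw ONE batch of indices and score every agent on exactly those examples."""
    batch = rng.sample(range(len(inputs)), batch_size)  # shared by all agents
    scores = []
    for agent in agents:
        # Hypothetical guide: 1.0 for an exact match with the target, 0.0 otherwise.
        scores.append([1.0 if agent(inputs[i]) == targets[i] else 0.0 for i in batch])
    return batch, scores


inputs = [0, 1, 2, 3, 4, 5]
targets = [2 * x for x in inputs]
agents = [lambda x: 2 * x, lambda x: x * x]  # two competing "proposals"
batch, scores = sample_same_batch(agents, inputs, targets, batch_size=3, rng=random.Random(0))
print(batch, scores)
```

Because both agents see the identical `batch`, comparing the mean of `scores[0]` with the mean of `scores[1]` is a paired comparison.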

log

log(info_log, prefix='')

Log the information from the algorithm.

test

test(test_dataset, guide)

save

save(save_path)

update

update(samples=None, verbose=False, **kwargs)

Update the agent based on the provided samples.

Args:
    samples (list): A list of samples from the previous iteration. If None, the agent's parameters are returned without updating.
    verbose (bool, optional): Whether to print verbose output. Defaults to False.
    **kwargs: Additional keyword arguments that may be used by the implementation.

Returns:
    update_dict (dict of Parameter: Any): A dictionary containing the updated parameters of the agent.
    proposals (list of trace.Module): A list of proposed parameters (trace.Module) after the update.
    info_log (dict of str: Any): A dictionary containing logging information about the update process.

This method updates the agent's parameters based on samples of the training dataset and the validation dataset (provided by self.get_validate_dataset). In addition, it returns new agents (proposals) that can be used to collect data for the next iteration.
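The three-part return value can be illustrated with a toy function that follows the same contract. This is a hypothetical sketch, not the library's `update`: `ToyParameter` and `toy_update` are invented, the "keep the highest-scoring candidate" rule and the suffixed proposal variants are assumptions made for illustration, proposals are plain dicts rather than trace.Module objects, and the `samples=None` branch simply hands back the current value to mirror the documented "returned without updating" behavior.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyParameter:
    """Hypothetical stand-in for a trainable Parameter (hashable by identity)."""

    def __init__(self, name: str, value: Any):
        self.name = name
        self.value = value


def toy_update(
    param: ToyParameter,
    samples: Optional[List[Tuple[str, float]]] = None,
) -> Tuple[Dict[ToyParameter, Any], List[Dict[str, Any]], Dict[str, Any]]:
    """Return (update_dict, proposals, info_log) in the documented order."""
    if samples is None:
        # No samples: hand back the current value without updating.
        return {param: param.value}, [{param.name: param.value}], {"updated": False}
    # Hypothetical rule: keep the highest-scoring candidate value.
    best_value, best_score = max(samples, key=lambda s: s[1])
    update_dict = {param: best_value}
    # Hypothetical proposal rule: the winner plus two variants for the next round.
    proposals = [{param.name: best_value + suffix} for suffix in ("", " v1", " v2")]
    info_log = {"updated": True, "best_score": best_score, "n_samples": len(samples)}
    return update_dict, proposals, info_log


prompt = ToyParameter("prompt", "Answer briefly.")
update_dict, proposals, info_log = toy_update(
    prompt, samples=[("Answer briefly.", 0.4), ("Think step by step.", 0.7)]
)
for p, value in update_dict.items():
    p.value = value  # the caller applies the update
print(prompt.value)            # Think step by step.
print(len(proposals))          # 3
print(info_log["best_score"])  # 0.7
```

Separating `update_dict` from `proposals` lets the caller apply one chosen update while still sampling several candidates in the next iteration.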