opto.features.priority_search.sampler
Rollout (dataclass)
A rollout is a single sample from the environment: the result of evaluating one agent on one input. It stores the module, input, info, target, score, and feedback.
RolloutsGraph
A rollouts graph is a collection of rollouts generated by the same agent (trace.Module) on different inputs.
Initialize a rollouts graph with the given rollouts.
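To make the relationship between the two containers concrete, here is a minimal sketch using stand-in classes. `RolloutSketch`, `RolloutsGraphSketch`, and the `mean_score` helper are hypothetical names introduced for illustration; the field names are assumed from the descriptions on this page and may not match the real dataclasses exactly.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class RolloutSketch:
    """Hypothetical stand-in for Rollout; field names are assumptions."""
    module: Any               # the agent (a trace.Module in the real library)
    x: Any                    # the input
    info: Any                 # additional information about the input
    target: Any               # the output recorded for this input
    score: Optional[float]    # score assigned by the guide
    feedback: Optional[str]   # feedback returned by the guide


@dataclass
class RolloutsGraphSketch:
    """Hypothetical stand-in for RolloutsGraph: rollouts of a single agent."""
    rollouts: List[RolloutSketch] = field(default_factory=list)

    @property
    def module(self) -> Any:
        # All rollouts share one agent, so the first one identifies it.
        return self.rollouts[0].module if self.rollouts else None

    def mean_score(self) -> Optional[float]:
        # Illustrative helper (not part of the documented API).
        scores = [r.score for r in self.rollouts if r.score is not None]
        return sum(scores) / len(scores) if scores else None


agent = "agent-0"  # placeholder for a trace.Module
graph = RolloutsGraphSketch([
    RolloutSketch(agent, "2 + 2", "4", "4", 1.0, "Correct."),
    RolloutSketch(agent, "3 * 3", "9", "6", 0.0, "Expected 9."),
])
print(graph.module, graph.mean_score())
```

The sketch only mirrors the data layout; the real classes may carry extra fields or methods.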
RolloutConfig (dataclass)
Sampler
Sampler(
loader,
guide,
num_threads=1,
sub_batch_size=None,
forward=None,
score_range=(-np.inf, np.inf),
)
A sampler that samples a batch of data from the loader and evaluates the agents on the sampled inputs.
Initialize the sampler with a data loader and a guide.
Args:
    loader (DataLoader): The data loader to sample from.
    guide (Guide): The guide to evaluate the proposals.
    num_threads (int): Number of threads to use for sampling.
    sub_batch_size (int, optional): Size of the sub-batch to use for sampling. If None, uses the batch size.
    score_range (tuple): The range of scores to consider valid.
sample
Sample a batch of data from the loader and evaluate the agents.
Args:
    agents (list): A list of trace.Modules (proposed parameters) to evaluate.
Returns:
    batch (dict): A dictionary containing the sampled inputs and infos, where:
        - 'inputs': a list of inputs sampled from the loader
        - 'infos': a list of additional information for each input
    samples (list of RolloutsGraph): A list of RolloutsGraph objects, each containing the rollouts generated by the agents on the sampled inputs. Each RolloutsGraph contains:
        - 'module': the trace.Module (proposal)
        - 'rollouts': a list of Rollout objects containing:
            - 'x': the input data
            - 'info': additional information about the input
            - 'target': the target output (if applicable)
            - 'score': the score of the proposal
            - 'feedback': the feedback from the guide
NOTE: The returned samples are not guaranteed to be in the same order as the agents.
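Because the order of the returned samples is not guaranteed, callers may want to look results up by module rather than by position. The sketch below shows one way to do that. It uses `SimpleNamespace` objects as stand-ins for RolloutsGraph, and it assumes that each graph exposes a `module` attribute referring to the same object that was passed in as an agent; `index_by_module` is a hypothetical helper, not part of the library.

```python
from types import SimpleNamespace


def index_by_module(samples):
    """Map id(module) -> graph so results can be looked up per agent."""
    return {id(graph.module): graph for graph in samples}


agent_a, agent_b = object(), object()  # placeholders for trace.Modules

# Pretend the sampler returned the graphs in a different order than the agents.
samples = [
    SimpleNamespace(module=agent_b, rollouts=[SimpleNamespace(score=0.2)]),
    SimpleNamespace(module=agent_a, rollouts=[SimpleNamespace(score=0.9)]),
]

by_module = index_by_module(samples)
# Restore the caller's ordering explicitly.
ordered = [by_module[id(agent)] for agent in (agent_a, agent_b)]
print([g.rollouts[0].score for g in ordered])
```

If several graphs can share one module (for example, one per sub-batch), a mapping from module to a list of graphs would be needed instead.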
standard_forward
Forward and compute feedback.
Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide
    min_score: score to return when an exception occurs
Returns:
    target: output of the agent
    score: score from the guide
    feedback: feedback from the guide
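The contract described above (run the agent, ask the guide for a score and feedback, and fall back to `min_score` on failure) can be sketched with plain callables. `forward_sketch`, `exact_match_guide`, and `failing_agent` are hypothetical names used only for illustration. Which exceptions the real function catches, and what it returns as target and feedback in that case, are assumptions here.

```python
def forward_sketch(agent, x, guide, info, min_score=0.0):
    """Run the agent on x and score the result with the guide.

    Assumption: any exception yields (None, min_score, <error message>).
    """
    try:
        target = agent(x)
        score, feedback = guide(x, target, info)
    except Exception as exc:
        target, score, feedback = None, min_score, f"Exception: {exc}"
    return target, score, feedback


def exact_match_guide(question, student_answer, info):
    """Toy guide: here info carries the reference answer."""
    if student_answer == info:
        return 1.0, "Correct."
    return 0.0, f"Expected {info!r}."


def failing_agent(x):
    raise ValueError("boom")


print(forward_sketch(str.upper, "hi", exact_match_guide, "HI"))
print(forward_sketch(failing_agent, "hi", exact_match_guide, "HI", min_score=-1.0))
```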
sample_rollouts
sample_rollouts(
configs,
num_threads=1,
forward=None,
min_score=None,
description="Sampling rollouts.",
) -> List[RolloutsGraph]
Sample a batch of data based on the proposed parameters. All proposals are evaluated on the same batch of inputs.
Args:
    configs (List[RolloutConfig]): A list of RolloutConfig objects, each containing:
        - module: the trace.Module (proposal) to evaluate
        - xs: a list of input data to evaluate the proposal on
        - infos: a list of additional information about the inputs
        - guide: the guide to evaluate the proposals
    num_threads (int): Number of threads to use for sampling.
    forward (callable, optional): A custom forward function to use instead of the default one (standard_forward). If None, the default forward function is used.
    min_score (float, optional): Minimum score to return when an exception occurs. If None, it defaults to 0.
    description (str): Description to display in the progress bar.
Returns:
    List[RolloutsGraph]: A list of RolloutsGraph objects, one for each config.
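As a rough illustration of the flow described above, the following sketch evaluates every (config, input) pair on a thread pool and groups the results per config. It is not the library's implementation: `ConfigSketch`, `default_forward`, and `sample_rollouts_sketch` are hypothetical stand-ins, the forward signature `(module, x, guide, info, min_score)` is an assumption, and the results are plain tuples rather than RolloutsGraph objects.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for RolloutConfig."""
    module: Callable[[Any], Any]
    xs: List[Any]
    infos: List[Any]
    guide: Callable[[Any, Any, Any], tuple]


def default_forward(module, x, guide, info, min_score):
    # Assumed fallback behaviour: any exception yields min_score.
    try:
        target = module(x)
        score, feedback = guide(x, target, info)
    except Exception as exc:
        target, score, feedback = None, min_score, str(exc)
    return target, score, feedback


def sample_rollouts_sketch(configs, num_threads=1, forward=None, min_score=None):
    forward = forward or default_forward
    min_score = 0 if min_score is None else min_score
    # Flatten to one task per (config, input) pair, remembering the config index.
    tasks = [
        (ci, c.module, x, c.guide, info)
        for ci, c in enumerate(configs)
        for x, info in zip(c.xs, c.infos)
    ]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # pool.map returns results in task order, even if they finish out of order.
        results = list(pool.map(lambda t: forward(t[1], t[2], t[3], t[4], min_score), tasks))
    grouped = [[] for _ in configs]
    for (ci, *_), result in zip(tasks, results):
        grouped[ci].append(result)
    return grouped  # one list of (target, score, feedback) per config


def guide(x, y, info):
    return (1.0, "ok") if y == info else (0.0, "wrong")


configs = [
    ConfigSketch(str.upper, ["a", "b"], ["A", "B"], guide),
    ConfigSketch(str.lower, ["a"], ["A"], guide),
]
out = sample_rollouts_sketch(configs, num_threads=2)
print(out)
```

Flattening to per-input tasks keeps all threads busy even when configs have different numbers of inputs; the real function may schedule work differently.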