
opto.features.priority_search.sampler

Rollout dataclass

Rollout(
    module: Module,
    x: Any,
    info: Any,
    target: Node,
    score: float,
    feedback: Any,
)

A rollout is a single sample from the environment: the result of evaluating an agent on one input. It stores the module, the input (x), its info, the agent's output (target), and the score and feedback produced by the guide.

module instance-attribute

module: Module

x instance-attribute

x: Any

info instance-attribute

info: Any

target instance-attribute

target: Node

score instance-attribute

score: float

feedback instance-attribute

feedback: Any

to_dict

to_dict()

Convert the rollout to a dictionary representation.
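To make the shape of a rollout concrete, here is a minimal stand-in that mirrors the documented fields. It is a sketch, not the library's class: `module` and `target` are typed as `Any` so it runs without opto installed, and the `to_dict` body (one key per field, values kept as-is) is an assumption about what the real method returns.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class Rollout:
    """Hypothetical stand-in mirroring the documented Rollout fields.

    The real class types `module` as trace.Module and `target` as trace.Node;
    `Any` is used here so the sketch runs without opto installed.
    """
    module: Any
    x: Any
    info: Any
    target: Any
    score: float
    feedback: Any

    def to_dict(self) -> Dict[str, Any]:
        # Assumption: one key per field, values kept as-is (shallow, no copying).
        return {f.name: getattr(self, f.name) for f in fields(self)}


r = Rollout(module="agent-0", x="2 + 2", info="4", target="4", score=1.0, feedback="Correct.")
print(r.to_dict()["score"])  # 1.0
```

A shallow dict avoids copying the module, which may be a large object.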

RolloutsGraph

RolloutsGraph(rollouts)

A rollouts graph is a collection of rollouts generated by the same agent (trace.Module) on different inputs.

Initialize a rollouts graph with the given rollouts.

module instance-attribute

module: Module = module

rollouts instance-attribute

rollouts: List[Rollout] = rollouts

get_scores

get_scores()

Get the scores of the rollouts in this graph.

extend

extend(other)

Extend this graph with the rollouts of another RolloutsGraph.

to_list

to_list()

Convert this graph to a list of rollouts.
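The sketch below mirrors how a RolloutsGraph groups the rollouts of a single agent. It is a hypothetical stand-in: deriving `module` from the first rollout and rejecting `extend` with a graph of a different module are assumptions, since the real constructor only documents the `rollouts` argument.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Rollout:
    # Hypothetical stand-in with the documented fields (types simplified to Any).
    module: Any
    x: Any
    info: Any
    target: Any
    score: float
    feedback: Any


class RolloutsGraph:
    """Hypothetical sketch: the rollouts of one agent on different inputs."""

    def __init__(self, rollouts: List[Rollout]):
        # Assumption: all rollouts share one module, so it is read from the first.
        if not rollouts:
            raise ValueError("rollouts must be non-empty")
        module = rollouts[0].module
        if any(r.module is not module for r in rollouts):
            raise ValueError("all rollouts must come from the same module")
        self.module = module
        self.rollouts = list(rollouts)

    def get_scores(self) -> List[float]:
        return [r.score for r in self.rollouts]

    def extend(self, other: "RolloutsGraph") -> None:
        # Assumption: only graphs of the same module may be merged (in place).
        if other.module is not self.module:
            raise ValueError("cannot extend with rollouts of a different module")
        self.rollouts.extend(other.rollouts)

    def to_list(self) -> List[Rollout]:
        return list(self.rollouts)


agent = object()  # placeholder for a trace.Module
g1 = RolloutsGraph([Rollout(agent, "q1", None, "a1", 1.0, "ok")])
g2 = RolloutsGraph([Rollout(agent, "q2", None, "a2", 0.0, "wrong")])
g1.extend(g2)
print(g1.get_scores())  # [1.0, 0.0]
```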

RolloutConfig dataclass

RolloutConfig(
    module: Module,
    xs: List[Any],
    infos: List[Any],
    guide: Any,
)

Initialize a rollout config with the given module, inputs, infos, and guide.

module instance-attribute

module: Module = module

xs instance-attribute

xs: List[Any] = xs

infos instance-attribute

infos: List[Any] = infos

guide instance-attribute

guide: Any = guide

Sampler

Sampler(
    loader,
    guide,
    num_threads=1,
    sub_batch_size=None,
    forward=None,
    score_range=(-np.inf, np.inf),
)

A sampler that samples a batch of data from the loader and evaluates the agents on the sampled inputs.

Initialize the sampler with a data loader and a guide.

Args:
    loader (DataLoader): The data loader to sample from.
    guide (Guide): The guide to evaluate the proposals.
    num_threads (int): Number of threads to use for sampling.
    sub_batch_size (int, optional): Size of the sub-batch to use for sampling. If None, uses the batch size.
    score_range (tuple): The range of scores to consider valid.

loader instance-attribute

loader = loader

guide instance-attribute

guide = guide

num_threads instance-attribute

num_threads = num_threads

sub_batch_size instance-attribute

sub_batch_size = sub_batch_size

score_range instance-attribute

score_range = score_range

forward instance-attribute

forward = standard_forward

dataset property writable

dataset

Get the dataset of the loader.

batch_size property writable

batch_size

Get the batch size of the loader.

n_epochs property

n_epochs

Get the number of epochs of the loader.

sample

sample(agents, description_prefix='')

Sample a batch of data from the loader and evaluate the agents.

Args:
    agents (list): A list of trace.Modules (proposed parameters) to evaluate.

Returns:
    batch (dict): A dictionary containing the sampled inputs and infos, where:
        - 'inputs': a list of inputs sampled from the loader
        - 'infos': a list of additional information for each input
    samples (list of RolloutsGraph):
        A list of RolloutsGraph objects, each containing the rollouts generated by one agent on the sampled inputs.
        Each RolloutsGraph contains:
        - 'module': the trace.Module (proposal)
        - 'rollouts': a list of Rollout objects containing:
            - 'x': the input data
            - 'info': additional information about the input
            - 'target': the output of the agent (if applicable)
            - 'score': the score of the proposal
            - 'feedback': the feedback from the guide

NOTE: The returned samples are not guaranteed to be in the same order as the agents.
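Because the returned samples may not line up with the agents, callers that need per-agent results can match each RolloutsGraph back to its agent through the `module` attribute. The helper below, `group_by_agent`, is hypothetical and uses `SimpleNamespace` objects in place of real RolloutsGraph instances; it matches by object identity so it does not depend on how trace.Module defines equality or hashing.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence


def group_by_agent(agents: Sequence[Any], graphs: Sequence[Any]) -> List[List[Any]]:
    """Return one list of graphs per agent, in the order of `agents`.

    Hypothetical helper: assumes each graph has a `module` attribute holding
    the very agent object that produced it, and matches by identity (id).
    """
    groups = {id(agent): [] for agent in agents}
    for graph in graphs:
        key = id(graph.module)
        if key not in groups:
            raise KeyError("graph was produced by an agent that is not in `agents`")
        groups[key].append(graph)
    return [groups[id(agent)] for agent in agents]


# Usage with stand-in objects in place of trace.Modules and RolloutsGraphs.
agent_a, agent_b = object(), object()
graphs = [SimpleNamespace(module=agent_b, scores=[0.2]),
          SimpleNamespace(module=agent_a, scores=[0.9])]
ordered = group_by_agent([agent_a, agent_b], graphs)
print([group[0].scores for group in ordered])  # [[0.9], [0.2]]
```

Grouping into lists rather than a one-to-one mapping keeps the helper correct even if a single agent ends up with more than one graph.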

standard_forward

standard_forward(agent, x, guide, info, min_score=0)

Run the agent forward on the input and compute its score and feedback with the guide.

Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide
    min_score: score to return when an exception is raised

Returns:
    target: output of the agent
    score: score from the guide
    feedback: feedback from the guide
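The documented contract of standard_forward can be sketched with plain callables. `forward_sketch` and `exact_match` are hypothetical names; the guide is modeled as a function with the documented (question, student_answer, info) -> (score, feedback) signature, whereas the real function works with trace.Module outputs (Nodes) and the library's guide objects, and may report exceptions differently.

```python
from typing import Any, Callable, Tuple


def forward_sketch(
    agent: Callable[[Any], Any],
    x: Any,
    guide: Callable[[Any, Any, Any], Tuple[float, Any]],
    info: Any,
    min_score: float = 0,
) -> Tuple[Any, float, Any]:
    """Run the agent on x, then score its output with the guide.

    Hypothetical sketch of the documented standard_forward contract.
    """
    target = None
    try:
        target = agent(x)                         # output of the agent
        score, feedback = guide(x, target, info)  # (question, student_answer, info) -> score, feedback
    except Exception as e:
        # Assumption: any exception yields min_score, with the error text as feedback.
        score, feedback = min_score, f"{type(e).__name__}: {e}"
    return target, score, feedback


def exact_match(question: Any, answer: Any, info: Any) -> Tuple[float, str]:
    # Hypothetical guide: info holds the expected answer.
    return (1.0, "Correct.") if answer == info else (0.0, f"Expected {info!r}.")


print(forward_sketch(str.upper, "abc", exact_match, "ABC"))
# ('ABC', 1.0, 'Correct.')
print(forward_sketch(lambda q: 1 / 0, "abc", exact_match, "ABC"))
# (None, 0, 'ZeroDivisionError: division by zero')
```

Returning min_score instead of raising lets one failing input lower a proposal's score without aborting the whole batch.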

sample_rollouts

sample_rollouts(
    configs,
    num_threads=1,
    forward=None,
    min_score=None,
    description="Sampling rollouts.",
) -> List[RolloutsGraph]

Sample rollouts for the proposed parameters by evaluating each config's module on that config's inputs. All proposals are evaluated on the same batch of inputs.

Args:
    configs (List[RolloutConfig]): A list of RolloutConfig objects, each containing
        - module: the trace.Module (proposal) to evaluate
        - xs: a list of input data to evaluate the proposal on
        - infos: a list of additional information about the inputs
        - guide: the guide to evaluate the proposals
    num_threads (int): Number of threads to use for sampling.
    forward (callable, optional): A custom forward function to use instead of the default one (standard_forward). If None, the default forward function is used.
    min_score (float, optional): Minimum score to return when an exception occurs. If None, it defaults to 0.
    description (str): Description to display in the progress bar.

Returns:
    List[RolloutsGraph]: A list of RolloutsGraph objects, one for each config.
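A minimal sketch of the fan-out that sample_rollouts describes: every (config, input) pair becomes one forward call on a thread pool, and the results are regrouped per config. Everything here is a stand-in: `RolloutConfig` mirrors the documented fields, `default_forward` approximates standard_forward, rollouts are plain dicts rather than Rollout objects, each group is a list rather than a RolloutsGraph, and the progress bar (description) is omitted.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class RolloutConfig:
    # Stand-in with the documented fields; module is any callable here.
    module: Any
    xs: List[Any]
    infos: List[Any]
    guide: Any


def default_forward(agent, x, guide, info, min_score=0):
    # Rough stand-in for standard_forward: returns (target, score, feedback).
    target = None
    try:
        target = agent(x)
        score, feedback = guide(x, target, info)
    except Exception as e:
        score, feedback = min_score, str(e)
    return target, score, feedback


def sample_rollouts_sketch(
    configs: List[RolloutConfig],
    num_threads: int = 1,
    forward: Optional[Callable] = None,
    min_score: Optional[float] = None,
) -> List[List[Dict[str, Any]]]:
    if forward is None:
        forward = default_forward
    if min_score is None:
        min_score = 0  # documented default
    for c in configs:
        if len(c.xs) != len(c.infos):
            raise ValueError("each config needs one info per input")

    # One job per (config index, input, info) triple.
    jobs = [(i, x, info) for i, c in enumerate(configs) for x, info in zip(c.xs, c.infos)]

    def run(job: Tuple[int, Any, Any]) -> Tuple[int, Dict[str, Any]]:
        i, x, info = job
        c = configs[i]
        target, score, feedback = forward(c.module, x, c.guide, info, min_score)
        return i, {"module": c.module, "x": x, "info": info,
                   "target": target, "score": score, "feedback": feedback}

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = list(pool.map(run, jobs))

    # Regroup the flat results into one list of rollouts per config.
    grouped: List[List[Dict[str, Any]]] = [[] for _ in configs]
    for i, rollout in results:
        grouped[i].append(rollout)
    return grouped


def guide(question, answer, info):
    return (1.0, "ok") if answer == info else (0.0, "wrong")


xs, infos = ["a", "b"], ["A", "b"]
configs = [RolloutConfig(str.upper, xs, infos, guide),
           RolloutConfig(str.lower, xs, infos, guide)]
grouped = sample_rollouts_sketch(configs, num_threads=4)
print([[r["score"] for r in g] for g in grouped])  # [[1.0, 0.0], [0.0, 1.0]]
```

Flattening to per-input jobs keeps all threads busy even when configs have different numbers of inputs. Since `pool.map` yields results in submission order, this sketch happens to preserve config order; the real function may not.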