
opto.trainer.algorithms.basic_algorithms

Minibatch

Minibatch(
    agent,
    optimizer,
    num_threads: int = None,
    logger=None,
    *args,
    **kwargs
)

Bases: Trainer

General minibatch optimization algorithm. This class defines a general training and logging routine using minibatch sampling.

optimizer instance-attribute

optimizer = optimizer

n_iters instance-attribute

n_iters = 0

train

train(
    guide,
    train_dataset,
    *,
    ensure_improvement: bool = False,
    improvement_threshold: float = 0.0,
    num_epochs: int = 1,
    batch_size: int = 1,
    test_dataset=None,
    eval_frequency: int = 1,
    num_eval_samples: int = 1,
    log_frequency: Union[int, None] = None,
    save_frequency: Union[int, None] = None,
    save_path: str = "checkpoints/agent.pkl",
    min_score: Union[int, None] = None,
    verbose: Union[bool, str] = False,
    num_threads: int = None,
    **kwargs
)

Given a dataset of (x, info) pairs, the algorithm will:

1. Forward the agent on the inputs and compute the feedback using the guide.
2. Update the agent using the feedback.
3. Evaluate the agent on the test dataset and log the results.
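The three steps above can be sketched as a generic minibatch loop. This is an illustrative, single-threaded sketch, not the actual `Trainer` implementation; the `forward` and `update` callables stand in for the class's methods of the same names:

```python
import random

def minibatch_train(forward, update, dataset, *, num_epochs=1, batch_size=1, seed=0):
    """Generic minibatch loop: sample a batch, forward each instance to get
    feedback, update from the batch, and log the resulting score."""
    rng = random.Random(seed)
    history = []
    for _ in range(num_epochs):
        data = list(dataset)
        rng.shuffle(data)  # minibatch sampling: a new order each epoch
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            outputs = [forward(x, info) for x, info in batch]  # 1. forward + feedback
            score = update(outputs)                            # 2. update the agent
            history.append(score)                              # 3. log the result
    return history
```

With equal-sized batches, the mean of the logged batch scores per epoch equals the mean score over the whole dataset.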

evaluate

evaluate(
    agent,
    guide,
    xs,
    infos,
    min_score=None,
    num_samples=1,
    num_threads=None,
    description=None,
)

Evaluate the agent on the given dataset.
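The evaluation logic can be sketched as follows. This is a simplified, single-threaded sketch under assumed semantics: the real method also takes the agent, guide, and a thread pool, and `score_fn` here is a hypothetical stand-in for the agent–guide scoring call:

```python
def evaluate_scores(score_fn, xs, infos, min_score=None, num_samples=1):
    """Score each (x, info) pair `num_samples` times and return the mean.
    If scoring raises and `min_score` is given, fall back to `min_score`."""
    scores = []
    for x, info in zip(xs, infos):
        for _ in range(num_samples):
            try:
                scores.append(score_fn(x, info))
            except Exception:
                if min_score is None:
                    raise  # no fallback configured: surface the error
                scores.append(min_score)  # count the failure as the floor score
    return sum(scores) / len(scores)
```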

has_improvement

has_improvement(
    xs,
    guide,
    infos,
    current_score,
    current_outputs,
    backup_dict,
    threshold=0,
    num_threads=None,
    *args,
    **kwargs
)

Check whether the updated agent improves on the current one.

Args:
    xs: inputs
    guide: the guide used to score the agent's outputs
    infos: additional information for the guide
    current_score: current score of the agent
    current_outputs: outputs of the agent-guide interaction
    backup_dict: backup of the current value of the parameters
    threshold: threshold for improvement
    num_threads: maximum number of threads to use
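The improvement check pairs naturally with a parameter backup and rollback. A minimal sketch of that pattern, assuming parameters live in a plain dict and `propose` / `evaluate` are illustrative stand-ins for the optimizer step and the guided evaluation:

```python
def guarded_update(params, propose, evaluate, threshold=0.0):
    """Apply a proposed update, then keep it only if the evaluated score
    improves by at least `threshold`; otherwise restore the backup."""
    backup = dict(params)            # backup_dict: snapshot of current parameters
    current_score = evaluate(params)
    propose(params)                  # the optimizer mutates params in place
    new_score = evaluate(params)
    if new_score < current_score + threshold:
        params.clear()
        params.update(backup)        # no improvement: roll back to the snapshot
        return False, current_score
    return True, new_score
```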

forward

forward(agent, x, guide, info)

Forward the agent on the input and compute the feedback using the guide.

Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide

Returns:
    outputs that will be used to update the agent

update

update(outputs, verbose=False, num_threads=None, **kwargs)

Subclasses can implement this method to update the agent.

Args:
    outputs: returned value from self.step
    verbose: whether to print the output of the agent
    num_threads: maximum number of threads to use (overrides self.num_threads)

Returns:
    score: average score of the minibatch of inputs

MinibatchAlgorithm

MinibatchAlgorithm(
    agent,
    optimizer,
    num_threads: int = None,
    logger=None,
    *args,
    **kwargs
)

Bases: Minibatch

The computed outputs of the instances in the minibatch are aggregated, and a single batched feedback is provided to update the agent.
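The aggregation step can be sketched as follows. This is an assumed shape, not the library's actual code: per-instance `(score, feedback)` pairs are averaged into one batch score, and the feedback strings are joined into one batched feedback:

```python
def aggregate_feedback(outputs):
    """Aggregate per-instance (score, feedback) pairs into a single
    batch score and one batched feedback string."""
    scores = [score for score, _ in outputs]
    batched = "\n\n".join(
        f"Instance {i}: {feedback}" for i, (_, feedback) in enumerate(outputs)
    )
    return sum(scores) / len(scores), batched
```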

forward

forward(agent, x, guide, info)

update

update(outputs, verbose=False, num_threads=None, **kwargs)

Subclasses can implement this method to update the agent.

Args:
    outputs: returned value from self.step
    verbose: whether to print the output of the agent
    num_threads: maximum number of threads to use (overrides self.num_threads)

Returns:
    score: average score of the minibatch of inputs

optimizer_step

optimizer_step(
    bypassing=False,
    verbose=False,
    num_threads=None,
    **kwargs
)

Subclasses can implement this method to update the agent.

BasicSearchAlgorithm

BasicSearchAlgorithm(
    agent,
    optimizer,
    num_threads: int = None,
    logger=None,
    *args,
    **kwargs
)

Bases: MinibatchAlgorithm

A basic search algorithm that calls the optimizer multiple times to get candidates and selects the best one based on a validation set.

train

train(
    guide,
    train_dataset,
    *,
    validate_dataset=None,
    validate_guide=None,
    num_proposals=4,
    num_epochs=1,
    batch_size=1,
    test_dataset=None,
    eval_frequency=1,
    log_frequency=None,
    min_score=None,
    verbose=False,
    num_threads=None,
    **kwargs
)

optimizer_step

optimizer_step(
    bypassing=False,
    verbose=False,
    num_threads=None,
    **kwargs
)

Use the optimizer to propose multiple updates and select the best one based on validation score.
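The propose-and-select pattern can be sketched generically. Names here are illustrative (the real method drives the optimizer and scores candidates on the validation set via the guide):

```python
def best_of_n_step(propose, validate, num_proposals=4):
    """Generate `num_proposals` candidate updates and keep the one with the
    highest validation score (ties go to the first candidate seen)."""
    candidates = [propose(i) for i in range(num_proposals)]
    scored = [(validate(c), c) for c in candidates]
    best_score, best = max(scored, key=lambda pair: pair[0])
    return best, best_score
```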

standard_optimization_step

standard_optimization_step(
    agent, x, guide, info, min_score=0
)

Forward and compute feedback.

Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide
    min_score: minimum score to use when an exception happens

Returns:
    target: output of the agent
    score: score from the guide
    feedback: feedback from the guide

batchify

batchify(*items)

Concatenate the items into a single string.
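One plausible implementation is shown below; the exact numbering and separator format are assumptions, not the library's actual output:

```python
def batchify(*items):
    """Concatenate the items into a single string, numbering each entry
    so individual instances remain distinguishable in the batched text."""
    return "\n\n".join(f"({i}) {item}" for i, item in enumerate(items))
```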