opto.trainer.algorithms.basic_algorithms
Minibatch
Bases: Trainer
General minibatch optimization algorithm. This class defines a general training and logging routine that uses minibatch sampling.
train
train(
guide,
train_dataset,
*,
ensure_improvement: bool = False,
improvement_threshold: float = 0.0,
num_epochs: int = 1,
batch_size: int = 1,
test_dataset=None,
eval_frequency: int = 1,
num_eval_samples: int = 1,
log_frequency: Union[int, None] = None,
save_frequency: Union[int, None] = None,
save_path: str = "checkpoints/agent.pkl",
min_score: Union[int, None] = None,
verbose: Union[bool, str] = False,
num_threads: int = None,
**kwargs
)
Given a dataset of (x, info) pairs, the algorithm will:
1. Forward the agent on the inputs and compute the feedback using the guide.
2. Update the agent using the feedback.
3. Evaluate the agent on the test dataset and log the results.
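The three steps can be illustrated with a small, self-contained sketch. It does not use the opto API: the agent is a hypothetical dict holding one trainable weight `w`, the guide follows the documented `(question, student_answer, info) -> score, feedback` shape, and the update rule (nudge `w` by the average direction of the feedback, scaled by a hypothetical `lr`) stands in for a real optimizer.

```python
from typing import Dict, List, Tuple

# Hypothetical toy agent: a single trainable weight, output = w * x.
def toy_agent(params: Dict[str, float], x: float) -> float:
    return params["w"] * x

# Hypothetical guide: (question, student_answer, info) -> (score, feedback).
def toy_guide(x: float, answer: float, info: float) -> Tuple[float, str]:
    if answer < info:
        return -abs(info - answer), "increase w"
    if answer > info:
        return -abs(info - answer), "decrease w"
    return 0.0, "correct"

def train(params, guide, xs, infos, num_epochs=1, batch_size=1, lr=0.5,
          test_xs=None, test_infos=None, eval_frequency=1):
    log: List[Tuple[int, float]] = []
    signs = {"increase w": 1.0, "decrease w": -1.0, "correct": 0.0}
    step = 0
    for _epoch in range(num_epochs):
        for start in range(0, len(xs), batch_size):
            batch = list(zip(xs[start:start + batch_size], infos[start:start + batch_size]))
            # 1. Forward the agent on each input and compute feedback with the guide.
            feedbacks = [guide(x, toy_agent(params, x), info)[1] for x, info in batch]
            # 2. Update the agent: move w by the average feedback direction.
            params["w"] += lr * sum(signs[fb] for fb in feedbacks) / len(feedbacks)
            step += 1
            # 3. Every eval_frequency steps, evaluate on the test set and log the mean score.
            if test_xs is not None and step % eval_frequency == 0:
                scores = [guide(x, toy_agent(params, x), info)[0]
                          for x, info in zip(test_xs, test_infos)]
                log.append((step, sum(scores) / len(scores)))
    return log

params = {"w": 0.0}
xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
log = train(params, toy_guide, xs, ys, num_epochs=4, batch_size=2,
            test_xs=xs, test_infos=ys)
print(params["w"], log[-1])  # 2.0 (8, 0.0)
```

With `batch_size=2`, each epoch takes two update steps, so four epochs produce eight logged evaluations; the weight reaches the target value 2.0 after four steps and stays there once the guide reports "correct".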
evaluate
evaluate(
agent,
guide,
xs,
infos,
min_score=None,
num_samples=1,
num_threads=None,
description=None,
)
Evaluate the agent on the given dataset.
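A minimal sketch of what such an evaluation can look like, assuming it averages guide scores over the dataset. Assumptions not stated above: each input is scored `num_samples` times, a failing call is scored `min_score` (and re-raised if `min_score` is None), the result is a single mean score, and `num_threads` is passed to `concurrent.futures.ThreadPoolExecutor`. The `description` parameter is omitted.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional, Sequence

def evaluate(agent: Callable[[Any], Any], guide, xs: Sequence, infos: Sequence,
             min_score: Optional[float] = None, num_samples: int = 1,
             num_threads: Optional[int] = None) -> float:
    def score_one(x, info) -> float:
        try:
            answer = agent(x)
            score, _feedback = guide(x, answer, info)
        except Exception:
            # Assumption: fall back to min_score on failure; re-raise if none was given.
            if min_score is None:
                raise
            score = min_score
        return score

    # Score every input num_samples times (useful when the agent is stochastic).
    jobs = [(x, info) for x, info in zip(xs, infos) for _ in range(num_samples)]
    # max_workers=None lets ThreadPoolExecutor choose its default pool size.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        scores = list(pool.map(lambda job: score_one(*job), jobs))
    return sum(scores) / len(scores)

exact_guide = lambda x, answer, info: (1.0 if answer == info else 0.0, "")
result = evaluate(lambda x: 1 / x, exact_guide, [1, 2, 0], [1.0, 0.5, 0.0],
                  min_score=0.0, num_samples=2, num_threads=2)
```

Here the input `0` raises `ZeroDivisionError`, so its two samples receive `min_score` while the other four score 1.0, giving a mean of 2/3.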
has_improvement
has_improvement(
xs,
guide,
infos,
current_score,
current_outputs,
backup_dict,
threshold=0,
num_threads=None,
*args,
**kwargs
)
Check if the updated agent is improved compared to the current one.
Args:
    xs: inputs
    guide: (question, student_answer, info) -> score, feedback
    infos: additional information for the guide
    current_score: current score of the agent
    current_outputs: outputs of the agent-guide interaction
    backup_dict: backup of the current values of the parameters
    threshold: threshold for improvement
    num_threads: maximum number of threads to use
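The check can be sketched as "re-score the same inputs with the updated parameters and compare against `current_score`". This is an illustration under assumptions, not the library's code: parameters are a plain dict, `agent` and `guide` are hypothetical callables, the comparison is a strict `new_score - current_score > threshold`, and a failed check restores the parameters from `backup_dict` (the text above only says `backup_dict` holds the pre-update values). `current_outputs` and `num_threads` are omitted.

```python
from typing import Dict

def mean_score(params, agent, guide, xs, infos) -> float:
    scores = [guide(x, agent(params, x), info)[0] for x, info in zip(xs, infos)]
    return sum(scores) / len(scores)

def has_improvement(params: Dict[str, float], agent, guide, xs, infos,
                    current_score: float, backup_dict: Dict[str, float],
                    threshold: float = 0.0) -> bool:
    # Re-score the same inputs with the *updated* parameters.
    new_score = mean_score(params, agent, guide, xs, infos)
    if new_score - current_score > threshold:
        return True
    # Assumption: not enough improvement, so roll back to the saved values.
    params.update(backup_dict)
    return False

agent = lambda p, x: p["w"] * x
guide = lambda x, answer, info: (-abs(info - answer), "")
xs, infos = [1.0, 2.0], [2.0, 4.0]
backup = {"w": 0.5}
current = mean_score(backup, agent, guide, xs, infos)  # -2.25 with the old weight

params = {"w": 1.0}  # parameters after a hypothetical optimizer update (score -1.5)
accepted = has_improvement(params, agent, guide, xs, infos, current, backup)
params_strict = {"w": 1.0}
rejected = has_improvement(params_strict, agent, guide, xs, infos, current, backup,
                           threshold=1.0)
```

The update improves the mean score by 0.75, which passes the default threshold of 0 but not a threshold of 1.0, in which case the weight is restored to 0.5.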
forward
Forward the agent on the input and compute the feedback using the guide.
Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide
Returns:
    outputs that will be used to update the agent
update
Subclasses can implement this method to update the agent.
Args:
    outputs: returned value from self.step
    verbose: whether to print the output of the agent
    num_threads: maximum number of threads to use (overrides self.num_threads)
Returns:
    score: average score of the minibatch of inputs
MinibatchAlgorithm
Bases: Minibatch
The computed output of each instance in the minibatch is aggregated and a batched feedback is provided to update the agent.
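The aggregation idea (one batched feedback for the whole minibatch instead of one per instance) can be sketched with plain strings. The `(target, score, feedback)` triple mirrors the documented return of `standard_optimization_step`; the `ID [i]` line format and the helper name `aggregate_minibatch` are hypothetical, and the library's own aggregation may operate on traced objects rather than strings.

```python
from typing import List, Tuple

def aggregate_minibatch(outputs: List[Tuple[str, float, str]]) -> Tuple[str, float]:
    """Merge per-instance (target, score, feedback) triples into one feedback string and the mean score."""
    # Illustrative format: tag each instance so the optimizer can tell them apart.
    lines = [f"ID [{i}]: output={target!r}; feedback: {feedback}"
             for i, (target, _score, feedback) in enumerate(outputs)]
    batched_feedback = "\n".join(lines)
    average_score = sum(score for _t, score, _f in outputs) / len(outputs)
    return batched_feedback, average_score

outputs = [("4", 1.0, "Correct."), ("5", 0.0, "Expected 6.")]
feedback, score = aggregate_minibatch(outputs)
print(feedback)
# ID [0]: output='4'; feedback: Correct.
# ID [1]: output='5'; feedback: Expected 6.
print(score)  # 0.5
```

A single batched feedback lets one optimizer call see every instance's outcome at once, which is the design choice the class description names.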
update
Subclasses can implement this method to update the agent.
Args:
    outputs: returned value from self.step
    verbose: whether to print the output of the agent
    num_threads: maximum number of threads to use (overrides self.num_threads)
Returns:
    score: average score of the minibatch of inputs
optimizer_step
Subclasses can implement this method to update the agent.
BasicSearchAlgorithm
Bases: MinibatchAlgorithm
A basic search algorithm that calls the optimizer multiple times to propose candidate updates and selects the one with the best score on the validation set.
train
train(
guide,
train_dataset,
*,
validate_dataset=None,
validate_guide=None,
num_proposals=4,
num_epochs=1,
batch_size=1,
test_dataset=None,
eval_frequency=1,
log_frequency=None,
min_score=None,
verbose=False,
num_threads=None,
**kwargs
)
optimizer_step
Use the optimizer to propose multiple updates and select the best one based on validation score.
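The propose-then-select step can be sketched as follows. Everything here is a stand-in: `propose` is a hypothetical replacement for one optimizer call (a random nudge of a single weight rather than an LLM-driven update), `validation_score` is a hypothetical mean of negated absolute errors, and keeping the current parameters in the candidate pool is an assumption that guarantees the validation score never drops.

```python
import random
from typing import Dict, List

def validation_score(params: Dict[str, float], xs: List[float], infos: List[float]) -> float:
    # Mean of -|error| over the validation set (higher is better).
    return sum(-abs(info - params["w"] * x) for x, info in zip(xs, infos)) / len(xs)

def propose(params: Dict[str, float], rng: random.Random) -> Dict[str, float]:
    # Hypothetical stand-in for one optimizer call: a random nudge of the weight.
    return {"w": params["w"] + rng.choice([-1.0, -0.5, 0.5, 1.0])}

def optimizer_step(params: Dict[str, float], val_xs: List[float], val_infos: List[float],
                   num_proposals: int = 4, seed: int = 0) -> float:
    rng = random.Random(seed)
    # Assumption: the current parameters stay in the pool, so the chosen score
    # can never be lower than the current validation score.
    candidates = [dict(params)] + [propose(params, rng) for _ in range(num_proposals)]
    scores = [validation_score(c, val_xs, val_infos) for c in candidates]
    best = max(range(len(candidates)), key=lambda i: scores[i])
    params.update(candidates[best])
    return scores[best]

params = {"w": 0.0}
best_score = optimizer_step(params, [1.0, 2.0], [2.0, 4.0], num_proposals=4)
```

Each candidate costs one full pass over the validation set, so `num_proposals` trades extra evaluation time for a better chance of finding an improving update.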
standard_optimization_step
Forward the agent and compute the feedback.
Args:
    agent: trace.Module
    x: input
    guide: (question, student_answer, info) -> score, feedback
    info: additional information for the guide
    min_score: score to assign when an exception occurs
Returns:
    target: output of the agent
    score: score from the guide
    feedback: feedback from the guide
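A minimal sketch of this forward-and-feedback step with plain callables instead of a `trace.Module`. How the library represents a failed forward pass is not stated above; in this sketch the caught exception stands in for `target`, the score falls back to `min_score`, and the feedback string format is hypothetical. The `check` guide is also hypothetical.

```python
from typing import Any, Callable, Optional, Tuple

def standard_optimization_step(agent: Callable[[Any], Any], x: Any,
                               guide: Callable[[Any, Any, Any], Tuple[float, str]],
                               info: Any, min_score: Optional[float] = None
                               ) -> Tuple[Any, Optional[float], str]:
    try:
        # Forward: run the agent on the input.
        target = agent(x)
        # Feedback: the guide maps (question, student_answer, info) to (score, feedback).
        score, feedback = guide(x, target, info)
    except Exception as exc:
        # Assumption: the exception stands in for the output and is scored min_score.
        target, score, feedback = exc, min_score, f"Error: {exc!r}"
    return target, score, feedback

def check(x, answer, info):
    return (1.0, "Correct.") if answer == info else (0.0, f"Expected {info}, got {answer}.")

ok = standard_optimization_step(lambda x: 10 / x, 2, check, 5.0)
failed = standard_optimization_step(lambda x: 10 / x, 0, check, 5.0, min_score=0.0)
```

Catching the exception instead of propagating it keeps one failing instance from aborting a whole minibatch, while the error text still reaches the optimizer as feedback.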