
opto.trainer.algorithms.algorithm

AbstractAlgorithm

AbstractAlgorithm(agent, *args, **kwargs)

Abstract base class for all algorithms.

agent instance-attribute

agent = agent

train

train(*args, **kwargs)

Train the agent.
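A concrete algorithm keeps the agent that the constructor stores on `self.agent` and overrides `train`. The sketch below defines a local stand-in for `AbstractAlgorithm` that mirrors only the documented signature, so it runs without the library installed. That the base `train` raises `NotImplementedError` is an assumption, and the subclass `CountingAlgorithm` is hypothetical.

```python
from typing import Any


class AbstractAlgorithm:
    """Stand-in mirroring the documented signature; not the library class."""

    def __init__(self, agent: Any, *args: Any, **kwargs: Any) -> None:
        self.agent = agent

    def train(self, *args: Any, **kwargs: Any) -> Any:
        # Assumed: subclasses are expected to override this method.
        raise NotImplementedError


class CountingAlgorithm(AbstractAlgorithm):
    """Hypothetical algorithm that only counts the examples it sees."""

    def train(self, dataset, **kwargs):
        seen = 0
        for _ in dataset:
            seen += 1  # a real algorithm would update self.agent here
        return {"examples_seen": seen}


algo = CountingAlgorithm(agent="my-agent")
print(algo.agent)             # my-agent
print(algo.train([1, 2, 3]))  # {'examples_seen': 3}
```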

Trainer

Trainer(
    agent,
    num_threads: Optional[int] = None,
    logger=None,
    *args,
    **kwargs
)

Bases: AbstractAlgorithm

Defines the API for algorithms that train an agent from a dataset of (x, info) pairs.

agent: trace.Module (e.g. constructed by @trace.model)

teacher: (question, student_answer, info) -> score, feedback (e.g. info can contain the true answer)

train_dataset: dataset of (x, info) pairs
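The teacher contract can be illustrated with a plain function. This is a minimal sketch assuming exact-match grading, with `info` holding the true answer as the docstring suggests. The names `exact_match_teacher` and `train_dataset`, the 1.0/0.0 scoring, and the use of a list of tuples for the dataset are illustrative assumptions; the library may expect a different container.

```python
from typing import List, Tuple


def exact_match_teacher(question: str, student_answer: str, info: str) -> Tuple[float, str]:
    """Hypothetical teacher: score 1.0 on an exact match, else 0.0 plus feedback."""
    if student_answer.strip() == info.strip():
        return 1.0, "Correct."
    return 0.0, f"Expected {info!r} for {question!r}, got {student_answer!r}."


# A dataset of (x, info) pairs: x is the input, info carries the true answer.
train_dataset: List[Tuple[str, str]] = [
    ("2 + 2", "4"),
    ("capital of France", "Paris"),
]

for x, info in train_dataset:
    score, feedback = exact_match_teacher(x, "4", info)
    print(score, feedback)
```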

num_threads instance-attribute

num_threads = num_threads

logger instance-attribute

logger = logger if logger is not None else DefaultLogger()

save_agent

save_agent(save_path, iteration=None)

Save the agent to the specified path.

Args:
    save_path: Path to save the agent to.
    iteration: Current iteration number (for logging purposes).

Returns:
    str: The path where the agent was saved.
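One plausible shape for `save_agent`, following only the documented contract: save to `save_path`, use `iteration` for logging, and return the path. The stand-in `StubAgent` with a `save(path)` method, the creation of missing parent directories, and the log message format are assumptions for illustration, not the library's implementation.

```python
import os
import tempfile
from typing import Optional


class StubAgent:
    """Hypothetical agent that serializes a single parameter as text."""

    def __init__(self, prompt: str) -> None:
        self.prompt = prompt

    def save(self, path: str) -> None:
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.prompt)


def save_agent(agent: StubAgent, save_path: str, iteration: Optional[int] = None) -> str:
    """Save `agent` to `save_path` and return that path (sketch only)."""
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)  # assumed: create missing parent dirs
    agent.save(save_path)
    if iteration is not None:
        # iteration is used only for logging, as the docstring states
        print(f"Saved agent at iteration {iteration} to {save_path}")
    return save_path


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "ckpt", "agent.txt")
    path = save_agent(StubAgent("Answer concisely."), target, iteration=3)
    with open(path, encoding="utf-8") as f:
        print(f.read())  # Answer concisely.
```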

train

train(
    guide, train_dataset, num_threads: int = None, **kwargs
)
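The page lists the signature of `train` but not its body. A minimal sketch of one evaluation pass, assuming the guide follows the teacher contract described in the class docstring (`(question, student_answer, info) -> (score, feedback)`) and that `num_threads` bounds how many examples are evaluated in parallel. `TinyTrainer`, the callable agent, the per-call override of `num_threads`, and the mean-score return value are all hypothetical; a real trainer would also pass the feedback to an optimizer.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Tuple

Guide = Callable[[str, str, str], Tuple[float, str]]


class TinyTrainer:
    """Hypothetical trainer mirroring the documented constructor and train() signature."""

    def __init__(self, agent: Callable[[str], str], num_threads: Optional[int] = None) -> None:
        self.agent = agent
        self.num_threads = num_threads

    def _step(self, guide: Guide, x: str, info: str) -> Tuple[float, str]:
        answer = self.agent(x)         # forward pass of the agent
        return guide(x, answer, info)  # score and feedback for this example

    def train(self, guide: Guide, train_dataset: List[Tuple[str, str]],
              num_threads: Optional[int] = None, **kwargs) -> float:
        # Assumed: a per-call num_threads overrides the constructor value; default 1 worker.
        workers = num_threads or self.num_threads or 1
        with ThreadPoolExecutor(max_workers=workers) as pool:
            results = list(pool.map(lambda pair: self._step(guide, *pair), train_dataset))
        # A real algorithm would feed the feedback strings to an optimizer here.
        return sum(score for score, _ in results) / len(results) if results else 0.0


def guide(question: str, answer: str, info: str) -> Tuple[float, str]:
    return (1.0, "ok") if answer == info else (0.0, f"expected {info}")


trainer = TinyTrainer(agent=lambda x: x.upper(), num_threads=2)
print(trainer.train(guide, [("a", "A"), ("b", "c")]))  # 0.5
```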

save

save(path: str)

Save the guide to a file.

load

load(path: str)

Load the guide from a file.
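The page does not say how `save` and `load` serialize state. A minimal sketch, assuming the object's attributes are picklable and that round-tripping `__dict__` with the standard-library `pickle` module is acceptable; `Checkpointable` and its attributes are hypothetical, not the library's classes.

```python
import os
import pickle
import tempfile


class Checkpointable:
    """Hypothetical object with pickle-based save/load of its attributes."""

    def __init__(self, **state) -> None:
        self.__dict__.update(state)

    def save(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)

    def load(self, path: str) -> None:
        # Only unpickle files you trust: pickle can execute arbitrary code.
        with open(path, "rb") as f:
            self.__dict__.update(pickle.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "guide.pkl")
    Checkpointable(threshold=0.8, rubric="exact match").save(path)
    restored = Checkpointable()
    restored.load(path)
    print(restored.threshold, restored.rubric)  # 0.8 exact match
```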