opto.trace.nodes

NAME_SCOPES module-attribute

NAME_SCOPES = []

GRAPH module-attribute

GRAPH = Graph()

USED_NODES module-attribute

USED_NODES = ContextVar('USED_NODES', default=list())

T module-attribute

T = TypeVar('T')

IDENTITY_OPERATORS module-attribute

IDENTITY_OPERATORS = ('identity', 'clone')

x module-attribute

x = node('Node X')

y module-attribute

y = node('Node Y')

z module-attribute

z = MessageNode(
    "Node Z",
    inputs={"x": x, "y": y},
    description="[Add] This is an add operator of x and y.",
)

Graph

Graph()

Graph is a registry of all the nodes, forming a Directed Acyclic Graph (DAG).

Attributes: _nodes (defaultdict): An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name.

Notes: The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG). It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes. The register method assumes that elements in _nodes are never removed, which is important for maintaining the integrity of node names.

Initialize the Graph object.

The initialization sets up the _nodes attribute as a defaultdict of lists to store nodes by their names.

TRACE class-attribute instance-attribute

TRACE = True

roots property

roots

Get all root nodes in the graph.

Returns: list: A list of all root nodes in the graph. A root node is identified by its is_root attribute.

clear

clear()

Remove all nodes from the graph.

The clear function iterates over the current nodes stored in the _nodes attribute and deletes each node. After all nodes have been deleted, it reinitializes the _nodes attribute to an empty defaultdict of lists. This ensures that the graph is completely cleared and ready to be repopulated with new nodes if necessary.

Notes: After calling clear, any references to the previously stored nodes will become invalid. The function is called in unit tests to reset the state of the graph between test cases, ensuring that each test runs with a clean slate and is not affected by the state left by previous tests.

register

register(node)

Add a node to the graph.

Args: node: The node object to be registered in the graph.

Notes: The register function should only be called after the node has been properly initialized and its name has been set. It assumes that elements in the _nodes dictionary are never removed. Registration proceeds as follows:
    - Check that the input is a Node and that its name has the right format.
    - Split the node's name into the name and the identifier.
    - If NAME_SCOPES is not empty, prefix the name with the last scope in the list followed by a "/". This allows node names to be scoped.
    - Add the node to the _nodes dictionary using the modified name as the key.
    - Set the node's _name attribute to the modified name followed by the index of the node in the list of nodes with the same name.
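The naming steps in the notes above can be sketched in isolation. The following is a minimal stand-alone illustration under stated assumptions, not the library's implementation: `ToyGraph`, `ToyNode`, and the module-level `SCOPES` list are hypothetical stand-ins for Graph, Node, and NAME_SCOPES, and the ":" separator before the index is assumed from the "name:id" format.

```python
from collections import defaultdict

SCOPES = []  # hypothetical stand-in for NAME_SCOPES


class ToyNode:
    """Hypothetical node that only carries a name of the form "name:id"."""

    def __init__(self, name):
        self._name = name


class ToyGraph:
    """Hypothetical registry mirroring the steps described for register."""

    def __init__(self):
        self._nodes = defaultdict(list)  # name -> nodes sharing that name

    def register(self, node):
        parts = node._name.split(":")
        if len(parts) != 2:
            raise ValueError(f"expected 'name:id', got {node._name!r}")
        name = parts[0]
        if len(SCOPES) > 0:
            name = SCOPES[-1] + "/" + name  # prefix with the last scope
        self._nodes[name].append(node)
        # The identifier is the node's index among nodes with the same name.
        node._name = name + ":" + str(len(self._nodes[name]) - 1)


graph = ToyGraph()
a, b, c = ToyNode("x:0"), ToyNode("x:0"), ToyNode("x:0")
graph.register(a)
graph.register(b)
SCOPES.append("outer")
graph.register(c)
SCOPES.pop()
print(a._name, b._name, c._name)  # x:0 x:1 outer/x:0
```

Because the identifier in this sketch is derived from the list length, removing an entry would let a later node reuse an identifier, which is one way to read the assumption that entries are never removed.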

get

get(name)

Retrieve a node from the graph by its name.

Args: name (str): A string in the format "name:id", where "name" is the name of the node and "id" is the identifier of the node.

Returns: Node: The requested node from the graph.

Notes: Ensure that the 'name' parameter is correctly formatted as "name:id" before calling this function. The function assumes that the '_nodes' attribute is a dictionary where each key is a node name and the corresponding value is a list of nodes. The 'id' should be a valid index within the list of nodes for the given 'name'.
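The lookup described above amounts to splitting the string and indexing into a list. A minimal sketch, assuming a plain defaultdict of lists in place of the graph's _nodes attribute; `get_by_name` is a hypothetical helper, not part of the library.

```python
from collections import defaultdict


def get_by_name(nodes, name):
    """Look up an entry using a string of the form "name:id"."""
    base, identifier = name.split(":")
    # The id must be a valid index into the list stored under the name.
    return nodes[base][int(identifier)]


nodes = defaultdict(list)
nodes["x"].extend(["first x", "second x"])
print(get_by_name(nodes, "x:1"))  # second x
```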

AbstractNode

AbstractNode(value, *, name=None, trainable=False)

Bases: Generic[T]

AbstractNode represents an abstract data node in a directed graph.

Notes: The AbstractNode class is meant to be subclassed and extended to create specific types of nodes. The node can have multiple parents and children, forming a directed graph structure. The node has a name, which is used to identify it within the graph. The py_name attribute is the same as the name attribute, but with the ":" character removed.

The node can be initialized with a value, an optional name, and an optional trainable flag.
If the value is an instance of the `Node` class, the node will be initialized as a reference to that node, otherwise, the value will be stored directly in the node.
The default name is generated based on the type of the value and a version number which serves as the identifier, separated by ":".

The `AbstractNode` class provides several properties to access its attributes. The `data` property allows access to the stored data.
If the node is being traced within a context, the `data` property adds the node to the list of nodes used in that context.
The `parents` property returns a list of parent nodes, and the `children` property returns a list of child nodes.
The `name` property returns the name of the node, and the `py_name` property returns the name without the ":" character.
The `id` property returns the version number/identifier extracted from the name.
The `level` property returns the level of the node in the DAG.
The `is_root` property returns True if the node has no parents, and the `is_leaf` property returns True if the node has no children.

The `AbstractNode` class also provides internal methods to add parents and children to the node.
The `_add_child` method adds a child node to the node's list of children.
The `_add_parent` method adds a parent node to the node's list of parents and updates the level of the node based on the parent's level.

The `AbstractNode` class overrides the `__str__` method to provide a string representation of the node. The representation includes the name, the type of the data, and the data itself.
The `AbstractNode` class implements the `__deepcopy__` method to create a deep copy of the node. This allows the node to be detached from the original graph.
The `AbstractNode` class provides comparison methods `lt` and `gt` to compare the levels of two nodes in the DAG.

Initialize an instance of the AbstractNode class.

Args:
    value: The value to be assigned to the node.
    name (str, optional): The name of the node. Defaults to None.
    trainable (bool, optional): Whether the node is trainable or not. Defaults to False.

Notes: During initialization, this function generates a default name for the node based on the type of the value parameter. If the name parameter is provided, it is appended to the default name. The format of the name is "type:version", where the version is set to 0 if no name is provided. If the value parameter is an instance of the Node class, the _data attribute of the current node is set to the _data attribute of the value parameter, and the _name attribute is set to the _name attribute of the value parameter if no name is provided. Otherwise, the _data attribute is set to the value parameter itself, and the _name attribute is set to the default name. Finally, the function calls the register function of the GRAPH object to register the current node in the graph.

data property

data

Retrieve the internal data of a node.

Returns: Any: The internal data stored in the node.

Notes: If within a trace_nodes context and GRAPH.TRACE is True, adds the node to USED_NODES. This function assumes that the "_data" attribute exists within the node object. If this attribute is not present, an AttributeError will be raised.
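The tracing behavior in the notes above can be illustrated with a context variable. This is a simplified sketch under stated assumptions, not the library's code: `USED`, `TRACE`, `TracedValue`, and `trace_used` are hypothetical stand-ins for USED_NODES, GRAPH.TRACE, the node class, and trace_nodes, and the sketch uses None as the "not tracing" default for simplicity.

```python
from contextvars import ContextVar

USED = ContextVar("USED", default=None)  # hypothetical stand-in for USED_NODES
TRACE = True  # hypothetical stand-in for GRAPH.TRACE


class TracedValue:
    """Hypothetical node whose data property records each access."""

    def __init__(self, data):
        self._data = data

    @property
    def data(self):
        used = USED.get()
        if used is not None and TRACE:
            used.append(self)  # record that this node was read
        return self._data


class trace_used:
    """Hypothetical context manager that collects nodes read inside it."""

    def __enter__(self):
        self._used = []
        self._token = USED.set(self._used)
        return self._used

    def __exit__(self, *exc):
        USED.reset(self._token)
        return False


a, b = TracedValue(1), TracedValue(2)
_ = a.data  # outside the context: nothing is recorded
with trace_used() as used:
    total = a.data + b.data
print(total, len(used))  # 3 2
```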

parents property

parents

Get the parents of a node.

Returns: list: The list of parent nodes.

Notes: This property is an essential part of the graph structure and is used in various operations such as graph traversal and feedback propagation.

children property

children

Get the children of a node.

Returns: list: The list of child nodes.

Notes: This property is essential for accessing the hierarchical structure of nodes, allowing traversal and manipulation of the DAG.

name property

name

Get the name of the node.

Returns: str: The name of the node.

Notes: This property is set when the node is registered in the graph. It is a combination of the node's name and its index in the list of nodes with the same name. The index is incremented each time a new node with the same name is registered. This assumes that elements in the _nodes dictionary of the graph never get removed.

py_name property

py_name

Get the Python-friendly name of the node.

Returns: str: The name of the node with ":" characters removed.

id property

id

Get the identifier part of the node's name.

Returns: str: The identifier portion of the node's name (part after the colon).

Notes: The name property is a string formatted as "name:identifier". This property splits that string using the colon (":") delimiter and returns the second part, which corresponds to the identifier. Ensure that the name attribute contains a colon (":") to avoid index errors during the split operation.
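The relationship between name, py_name, and id can be illustrated with plain string operations. This is a sketch of the behavior described above rather than the library's code; `split_name` is a hypothetical helper.

```python
def split_name(full_name):
    """Return (py_name, base, identifier) for a name like "name:identifier"."""
    if ":" not in full_name:
        raise ValueError(f"expected 'name:identifier', got {full_name!r}")
    py_name = full_name.replace(":", "")  # the name with ":" removed
    parts = full_name.split(":")
    return py_name, parts[0], parts[1]  # the identifier is the second part


print(split_name("str:0"))  # ('str0', 'str', '0')
```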

level property

level

Get the level of the node in the graph.

Returns: int: The level of the node.

Notes: The level is determined by the maximum level of its parents plus one. The level of a root node is 0.

is_root property

is_root

Check if the node is a root node.

Returns: bool: True if the node has no parents, False otherwise.

is_leaf property

is_leaf

Check if the node is a leaf node.

Returns: bool: True if the node has no children, False otherwise.
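The level, is_root, and is_leaf properties above can be sketched together. The class below is a hypothetical, self-contained illustration of the rule "maximum parent level plus one, with roots at level 0"; it is not the library's implementation.

```python
class LevelNode:
    """Hypothetical node that tracks parents, children, and its DAG level."""

    def __init__(self, label):
        self.label = label
        self.parents = []
        self.children = []
        self.level = 0  # a root node has level 0

    def add_parent(self, parent):
        parent.children.append(self)
        self.parents.append(parent)
        # The level is the maximum parent level plus one.
        self.level = max(self.level, parent.level + 1)

    @property
    def is_root(self):
        return len(self.parents) == 0

    @property
    def is_leaf(self):
        return len(self.children) == 0


x, y, z, w = (LevelNode(label) for label in "xyzw")
z.add_parent(x)
z.add_parent(y)
w.add_parent(x)  # w is now at level 1
w.add_parent(z)  # w moves to level 2 because z is at level 1
print(z.level, w.level, x.is_root, w.is_leaf)  # 1 2 True True
```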

lt

lt(other)

Compare this node with another node by level, using negated levels.

Args: other: The other node to compare against.

Returns: bool: True if -self._level is less than -other._level, that is, if this node's level is greater than the other node's level.

Notes: This method is used to order nodes in the DAG by level. It checks whether the negated level of the current node (-self._level) is less than the negated level of the other node (-other._level), so a node at a higher level compares as smaller.

gt

gt(other)

Compare this node with another node by level, using negated levels.

Args: other: The other node to compare against.

Returns: bool: True if -self._level is greater than -other._level, that is, if this node's level is less than the other node's level.

Notes: This method is used to order nodes in the DAG by level. It checks whether the negated level of the current node (-self._level) is greater than the negated level of the other node (-other._level), so a node at a lower level compares as greater.
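Comparing negated levels inverts the usual ordering, since -a < -b holds exactly when a > b. One plausible use, assuming these comparisons drive the priority queue mentioned for backward (an inference, not something this page states), is to make Python's min-heap return the deepest node first. The `Leveled` class below is hypothetical.

```python
import heapq


class Leveled:
    """Hypothetical item ordered by negated level, as in the notes above."""

    def __init__(self, name, level):
        self.name = name
        self._level = level

    def __lt__(self, other):
        return -self._level < -other._level

    def __gt__(self, other):
        return -self._level > -other._level


heap = []
for item in (Leveled("root", 0), Leveled("output", 2), Leveled("middle", 1)):
    heapq.heappush(heap, item)

# heapq pops the "smallest" item, which here is the one with the highest level.
order = [heapq.heappop(heap).name for _ in range(3)]
print(order)  # ['output', 'middle', 'root']
```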

NodeVizStyleGuide

NodeVizStyleGuide(style='default', print_limit=100)

A class to provide a standardized way to visualize nodes in a graph.

Attributes:
    style (str): Defines the style of the visualization. Default is 'default'.
    print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100.

Initialize the NodeVizStyleGuide.

Args:
    style (str, optional): The style of visualization to use. Defaults to 'default'.
    print_limit (int, optional): Maximum characters to print for descriptions and content. Defaults to 100.

style instance-attribute

style = style

print_limit instance-attribute

print_limit = print_limit

get_attrs

get_attrs(x)

Get the attributes for a node based on the style guide.

Args: x: The node for which attributes are to be generated.

Returns: dict: Dictionary of visualization attributes for the node.

Notes: The attributes include the label, shape, fill color, and style of the node, which are determined based on the node's properties and the style guide. The method calls other helper methods to construct the label, determine the node shape, assign a color, and set the style.

get_label

get_label(x)

Construct a label for a node.

Args: x: The node for which the label is to be constructed.

Returns: str: The constructed label string.

Notes: Using a colon in the name can cause problems in graph visualization tools like Graphviz. To avoid issues, the label is constructed by combining the node's Python name, truncated description, and content. If the description or content exceeds the print limit, it is truncated and appended with an ellipsis.
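The truncation rule in the notes above can be sketched as follows. `truncate` and `make_label` are hypothetical helpers, and joining the three parts with newlines is an assumption about layout, since the page does not specify the separator.

```python
def truncate(text, limit):
    """Cut text to at most `limit` characters, appending an ellipsis if cut."""
    text = str(text)
    return text if len(text) <= limit else text[:limit] + "..."


def make_label(py_name, description, content, print_limit=100):
    """Combine the colon-free name, description, and content into one label."""
    parts = [py_name, truncate(description, print_limit), truncate(content, print_limit)]
    return "\n".join(parts)  # assumed separator


py_name = "x:0".replace(":", "")  # avoid colons, which can confuse Graphviz
label = make_label(py_name, "[Node] A plain data node.", "0123456789", print_limit=5)
print(label)
```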

get_node_shape

get_node_shape(x)

Determine the shape of a node.

Args: x: The node for which the shape is to be determined.

Returns: str: The shape to use for the node.

Notes: The shape of a node is determined based on its type. ParameterNode types are represented as 'box', while other types are represented as 'ellipse'.

get_color

get_color(x)

Assign a color to a node.

Args: x: The node for which the color is to be assigned.

Returns: str: The color to use for the node.

Notes: The color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'lightgray'.

get_style

get_style(x)

Set the style of a node.

Args: x: The node for which the style is to be set.

Returns: str: The style string for the node.

Notes: The style of a node is set to 'filled,solid' if the node is trainable; otherwise, it returns an empty string.
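The shape, color, and style rules described for this style guide can be summarized in a small sketch. Everything here is hypothetical: node types are represented by plain strings, the dictionary keys are assumed to follow Graphviz attribute names, and the `fallback` color is an assumption because the page names colors only for ExceptionNode and ParameterNode.

```python
def pick_shape(kind):
    # ParameterNode types are drawn as boxes, everything else as ellipses.
    return "box" if kind == "parameter" else "ellipse"


def pick_color(kind, fallback="white"):
    # 'fallback' is an assumption; the documented colors cover two kinds only.
    return {"exception": "firebrick1", "parameter": "lightgray"}.get(kind, fallback)


def pick_style(trainable):
    return "filled,solid" if trainable else ""


def node_attrs(label, kind, trainable):
    """Assemble the visualization attributes for one node."""
    return {
        "label": label,
        "shape": pick_shape(kind),
        "fillcolor": pick_color(kind),
        "style": pick_style(trainable),
    }


attrs = node_attrs("x0", "parameter", trainable=True)
print(attrs["shape"], attrs["fillcolor"], attrs["style"])  # box lightgray filled,solid
```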

NodeVizStyleGuideColorful

NodeVizStyleGuideColorful(style='default', print_limit=100)

Bases: NodeVizStyleGuide

A class to provide a colorful style guide for visualizing nodes in a graph.

Attributes:
    style (str): Defines the style of the visualization. Default is 'default'.
    print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100.

Initialize the NodeVizStyleGuideColorful.

Args:
    style (str, optional): The style of visualization to use. Defaults to 'default'.
    print_limit (int, optional): Maximum characters to print for descriptions and content. Defaults to 100.

style instance-attribute

style = style

print_limit instance-attribute

print_limit = print_limit

get_attrs

get_attrs(x)

Get the attributes for a node based on the colorful style guide.

Args: x: The node for which attributes are to be generated.

Returns: dict: Dictionary of visualization attributes for the node.

Notes: The attributes include the label, shape, fill color, style, border color, and border width of the node, which are determined based on the node's properties and the style guide. The method calls other helper methods to construct the label, determine the node shape, assign a color, and set the style.

get_border_color

get_border_color(x)

Assign a border color to a node.

Args: x: The node for which the border color is to be assigned.

Returns: str: The border color to use for the node.

Notes: The border color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'black'.

get_color

get_color(x)

Assign a fill color to a node.

Args: x: The node for which the fill color is to be assigned.

Returns: str: The fill color to use for the node.

Notes: The fill color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'lightgray'.

get_style

get_style(x)

Always set the style of a node as if the node were trainable.

Args: x: The node for which the style is to be set.

Returns: str: The style string 'filled,solid'.

Node

Node(
    value: Any,
    *,
    name: str = None,
    trainable: bool = False,
    description: str = None,
    info: Union[None, Dict] = None
)

Bases: AbstractNode[T]

A data node in a directed graph. This is the basic data structure of Trace.

Args:
    value (Any): The value to be assigned to the node.
    name (str, optional): The name of the node.
    trainable (bool, optional): Whether the node is trainable or not. Defaults to False.
    description (str, optional): String describing the node which acts as a soft constraint. Defaults to None.
    info (Union[None, Dict], optional): Dictionary containing additional information about the node. Defaults to None.

Attributes:
    trainable (bool): Whether the node is trainable or not.
    _feedback (dict): Dictionary of feedback from children nodes.
    _description (str): String describing the node. Defaults to "[Node]".
    _backwarded (bool): Whether the backward method has been called.
    _info (dict): Dictionary containing additional information about the node.
    _dependencies (dict): Dictionary of dependencies on parameters and expandable nodes.

Notes: The Node class extends AbstractNode to represent a data node in a directed graph. It includes attributes and methods to handle feedback, description, and dependencies. The node can be marked as trainable and store feedback from children nodes. The feedback mechanism is analogous to gradients in machine learning and propagates information back through the graph. The feedback mechanism supports non-commutative aggregation, so feedback should be handled carefully to maintain correct operation order. The node can track dependencies on parameters and expandable nodes (nodes that depend on parameters not visible in the current graph level).

trainable instance-attribute

trainable = trainable

feedback property

feedback

The feedback from children nodes.

description property

description

A textual description of the node.

op_name property

op_name

The operator type of the node, extracted from the description.

info property

info

Additional information about the node.

type property

type

The type of the data stored in the node.

parameter_dependencies property

parameter_dependencies

The parameters that this node depends on.

Notes: Ensure that the '_dependencies' attribute is properly initialized and contains a 'parameter' key before accessing parameter_dependencies, to avoid a KeyError.

expandable_dependencies property

expandable_dependencies

The expandable nodes that this node depends on.

Notes: Expandable nodes are those that depend on parameters not visible in the current graph level. Ensure that the '_dependencies' attribute is properly initialized and contains an 'expandable' key before accessing expandable_dependencies, to avoid a KeyError.

zero_feedback

zero_feedback()

Zero out the feedback of the node.

Notes: zero_feedback should be used judiciously within the feedback propagation process to avoid unintended loss of feedback data. It is specifically designed to be used after feedback has been successfully propagated to parent nodes.

backward

backward(
    feedback: Any = "",
    propagator=None,
    retain_graph=False,
    visualize=False,
    simple_visualization=True,
    reverse_plot=False,
    print_limit=100,
)

Performs a backward pass in a computational graph.

This function propagates feedback from the current node to its parents, updates the graph visualization if required, and returns the resulting graph.

Args:
    feedback: The feedback given to the current node.
    propagator: A function that takes in a node and a feedback, and returns a dict of {parent: parent_feedback}. If not provided, a default GraphPropagator object is used.
    retain_graph: If True, the graph will be retained after backward pass.
    visualize: If True, the graph will be visualized using graphviz.
    simple_visualization: If True, identity operators will be skipped in the visualization.
    reverse_plot: If True, plot the graph in reverse order (from child to parent).
    print_limit: The maximum number of characters to print for node descriptions and content.

Returns: digraph: The visualization graph object if visualize=True, None otherwise.

Raises: AttributeError: If the node has already been backwarded.

Notes: The function checks if the current node has already been backwarded. If it has, an AttributeError is raised. For root nodes (no parents), only visualization is performed if enabled. For non-root nodes, feedback is propagated through the graph using a priority queue to ensure correct ordering. The propagator computes feedback for parent nodes based on the current node's description, data and feedback. Visualization is handled using graphviz if enabled, with options to simplify the graph by skipping identity operators.
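The propagation order described in the notes above can be sketched with a priority queue. This is a simplified, hypothetical illustration and not the library's implementation: `ToyNode`, `pass_through`, and `toy_backward` are invented names, visualization and retain_graph are omitted, and the propagator simply forwards feedback unchanged.

```python
import heapq
from collections import defaultdict


class ToyNode:
    """Hypothetical node with just enough state to show the traversal."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)
        # Roots get level 0; otherwise one more than the maximum parent level.
        self.level = 1 + max((p.level for p in self.parents), default=-1)
        self.feedback = defaultdict(list)  # sender name -> list of feedback
        self.backwarded = False

    def __lt__(self, other):
        # Negated levels make heapq (a min-heap) pop the deepest node first.
        return -self.level < -other.level


def pass_through(node):
    """Hypothetical propagator: hand all collected feedback to every parent."""
    collected = [item for items in node.feedback.values() for item in items]
    return {parent: list(collected) for parent in node.parents}


def toy_backward(node, feedback, propagator=pass_through):
    if node.backwarded:
        raise AttributeError(f"{node.name} has been backwarded.")
    node.feedback["user"].append(feedback)
    queue, visited = [node], []
    while queue:
        current = heapq.heappop(queue)
        visited.append(current.name)
        for parent, parent_feedback in propagator(current).items():
            parent.feedback[current.name].extend(parent_feedback)
            if not any(parent is queued for queued in queue):
                heapq.heappush(queue, parent)
        current.backwarded = True  # assumption: every visited node is marked
    return visited


x, y = ToyNode("x"), ToyNode("y")
z = ToyNode("z", parents=[x, y])
w = ToyNode("w", parents=[z, x])
order = toy_backward(w, "increase the output")
print(order[:2], dict(x.feedback))
```

In this sketch, x is a parent of both w and z. Processing nodes from the highest level down means x has received feedback from both children before it is visited, which is the kind of ordering guarantee the priority queue provides.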

clone

clone()

Create and return a duplicate of the current Node object.

Returns: Node: A clone of the current node.

detach

detach()

Create and return a deep copy of the current instance of the Node class.

Returns: Node: A deep copy of the current node.

getattr

getattr(key)

Get the attribute of the node with the specified key.

Args: key: The key of the attribute to get.

Returns: Node: A node containing the requested attribute.

call

call(fun: str, *args, **kwargs)

Call the function with the specified arguments and keyword arguments.

Args:
    fun (str): The function to call.
    *args: The arguments to pass to the function.
    **kwargs: The keyword arguments to pass to the function.

Returns: Node: The result of the function call wrapped in a node.

len

len()

Return the length of the node.

Returns: Node: A node containing the length value.

Notes: We overload magic methods that return a value. This method returns a MessageNode.

eq

eq(other)

Check if the node is equal to another value.

Args: other: The value to compare the node to.

Returns: Node: A node containing the comparison result.

Notes: If a logic operator is used in an if-statement, it will return a boolean value. Otherwise, it will return a MessageNode.

neq

neq(other)

Check if the node is not equal to another value.

Args: other: The value to compare the node to.

Returns: Node: A node containing the comparison result.

Notes: If a logic operator is used in an if-statement, it will return a boolean value. Otherwise, it will return a MessageNode.

format

format(*args, **kwargs)

capitalize

capitalize()

lower

lower()

upper

upper()

swapcase

swapcase()

title

title()

split

split(sep=None, maxsplit=-1)

strip

strip(chars=None)

replace

replace(old, new, count=-1)

join

join(seq)

items

items()

values

values()

keys

keys()

pop

pop(__index=-1)

append

append(*args, **kwargs)

ParameterNode

ParameterNode(
    value,
    *,
    name=None,
    trainable=True,
    description=None,
    projections=None,
    info=None
)

Bases: Node[T]

projections instance-attribute

projections = projections

MessageNode

MessageNode(
    value,
    *,
    inputs: Union[List[Node], Dict[str, Node]],
    description: str,
    name=None,
    info=None
)

Bases: Node[T]

A node representing the output of an operator.

The description string should begin with [operator_name] followed by details about the operator. When referring to inputs in the description, use either:
    - The keys in inputs (if inputs is a dict)
    - The names of the nodes in inputs (if inputs is a list)

Examples:
    >>> MessageNode(node_a, inputs=[node_a],
    ...             description="[identity] This is an identity operator.")
    >>> MessageNode(copy_node_a, inputs=[node_a],
    ...             description="[copy] This is a copy operator.")
    >>> MessageNode(1, inputs={'a': node_a, 'b': node_b},
    ...             description="[Add] This is an add operator of a and b.")

Attributes: value: The output value of the operator

inputs property

inputs

(Union[List[Node], Dict[str, Node]]): Input nodes to the operator

hidden_dependencies property

hidden_dependencies

Returns the set of hidden dependencies that are not visible in the current graph level.

ExceptionNode

ExceptionNode(
    value: Exception,
    *,
    inputs: Union[List[Node], Dict[str, Node]],
    description: str = None,
    name=None,
    info=None
)

Bases: MessageNode[T]

Node containing the exception message.

create_feedback

create_feedback(style='simple')

node

node(data, name=None, trainable=False, description=None)

Create a Node object from data.

Args:
    data: The data to create the Node from.
    name (str, optional): The name of the Node.
    trainable (bool, optional): Whether the Node is trainable. Defaults to False.
    description (str, optional): A string describing the data.

Returns: Node: A Node object containing the data.

Notes:

If trainable=True:
    - If data is already a Node, extracts underlying data and updates name
    - Creates ParameterNode with extracted data, name, trainable=True

If trainable=False:
    - If data is already a Node, returns it (with warning if name provided)
    - Otherwise creates new Node with data, name
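The branching described in the notes above can be sketched with stand-in classes. `PlainNode`, `ParamNode`, and `make_node` are hypothetical names used only for illustration; in particular, reusing the existing name when none is given is an assumption about what "updates name" means.

```python
import warnings


class PlainNode:
    """Hypothetical stand-in for Node."""

    def __init__(self, data, name=None):
        self.data = data
        self.name = name


class ParamNode(PlainNode):
    """Hypothetical stand-in for ParameterNode."""


def make_node(data, name=None, trainable=False):
    if trainable:
        if isinstance(data, PlainNode):
            # Assumption: reuse the existing name unless a new one is given.
            name = data.name if name is None else name
            data = data.data  # extract the underlying data
        return ParamNode(data, name=name)
    if isinstance(data, PlainNode):
        if name is not None:
            warnings.warn("name is ignored because data is already a node")
        return data  # already a node: return it unchanged
    return PlainNode(data, name=name)


existing = PlainNode(1, name="a")
same = make_node(existing)
param = make_node(existing, trainable=True)
print(same is existing, type(param).__name__, param.data, param.name)  # True ParamNode 1 a
```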

get_op_name

get_op_name(description)

Extract the operator type from the description.

Args: description (str): A string containing the description of the node.

Returns: str: The extracted operator type.

Raises: ValueError: If the description does not contain an operator type in square brackets.

Notes: get_op_name uses a regular expression to search for the operator type enclosed in square brackets at the beginning of the description. If a match is found, the operator type is returned; otherwise, a ValueError is raised.
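A minimal sketch of this kind of extraction is shown below. The function name `extract_op_name`, the exact pattern, and the error message are assumptions made for illustration; the library's actual regular expression may differ.

```python
import re


def extract_op_name(description):
    """Return the text inside square brackets at the start of a description."""
    match = re.search(r"^\[([^\[\]]+)\]", description)  # assumed pattern
    if match is None:
        raise ValueError(f"no operator type found in {description!r}")
    return match.group(1)


print(extract_op_name("[Add] This is an add operator of x and y."))  # Add
```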