PyTree API


vector.register_pytree() → ReexportedPyTreeModule

Register optree PyTree operations for vector objects.

The registration defines how vector objects are handled by the optree package. See https://blog.scientific-python.org/pytrees/ for the rationale behind these functions.

After calling this function,

>>> import vector
>>> vector.register_pytree()
<module 'vector.pytree'>

the following classes can be flattened and unflattened with the optree package:

  • VectorObject*D

  • MomentumObject*D

  • VectorNumpy*D

  • MomentumNumpy*D

For example:

>>> import optree
>>> vec = vector.obj(x=1, y=2)
>>> leaves, treedef = optree.tree_flatten(vec, namespace="vector")
>>> vec2 = optree.tree_unflatten(treedef, leaves)
>>> assert vec == vec2

As a convenience, register_pytree() returns a re-exported module whose functions do not need the namespace argument. For example:

>>> pytree = vector.register_pytree()
>>> vec = vector.obj(x=1, y=2)
>>> leaves, treedef = pytree.flatten(vec)
>>> vec2 = pytree.unflatten(treedef, leaves)
>>> assert vec == vec2
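optree keeps custom registrations per namespace, which is why the first example passes namespace="vector". The toy, stdlib-only model below illustrates that idea; it is not optree's or vector's implementation. It assumes a hypothetical registry keyed by (namespace, type), where each flatten function returns (children, metadata) and each unflatten function rebuilds the object from (metadata, children), roughly mirroring the shape of optree's registration contract. Vec2D, register_node, tree_flatten and tree_unflatten are all hypothetical names, and the sketch handles only one level of nesting. The last lines bind the namespace with functools.partial, in the spirit of the convenience module returned by vector.register_pytree().

```python
import functools
import types
from dataclasses import dataclass

# Hypothetical registry: (namespace, type) -> (flatten_fn, unflatten_fn).
_REGISTRY = {}


def register_node(cls, flatten_fn, unflatten_fn, namespace):
    """Register flatten/unflatten functions for cls inside one namespace."""
    _REGISTRY[(namespace, cls)] = (flatten_fn, unflatten_fn)


def tree_flatten(tree, namespace=""):
    """Return (leaves, treedef); one level only, for illustration."""
    entry = _REGISTRY.get((namespace, type(tree)))
    if entry is None:
        # Not registered in this namespace: the whole object is one leaf.
        return [tree], None
    children, metadata = entry[0](tree)
    return list(children), (namespace, type(tree), metadata)


def tree_unflatten(treedef, leaves):
    """Rebuild an object from a treedef produced by tree_flatten."""
    if treedef is None:
        return leaves[0]
    namespace, cls, metadata = treedef
    return _REGISTRY[(namespace, cls)][1](metadata, leaves)


@dataclass
class Vec2D:  # hypothetical stand-in for VectorObject2D
    x: float
    y: float


register_node(
    Vec2D,
    lambda v: ((v.x, v.y), ("x", "y")),  # children, metadata (field names)
    lambda names, children: Vec2D(**dict(zip(names, children))),
    namespace="vector",
)

# Convenience object with the namespace pre-bound.
pytree = types.SimpleNamespace(
    flatten=functools.partial(tree_flatten, namespace="vector"),
    unflatten=tree_unflatten,
)

vec = Vec2D(x=1, y=2)
leaves, treedef = pytree.flatten(vec)
print(leaves)  # [1, 2]
assert pytree.unflatten(treedef, leaves) == vec
print(tree_flatten(vec)[0])  # [Vec2D(x=1, y=2)]  (default namespace: opaque leaf)
```

Keying the registry by namespace means registering vector's classes cannot change how other libraries that also use optree flatten the same objects; only callers that opt into the "vector" namespace see the custom behavior.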

The returned module also provides a ravel function, which flattens a VectorNumpy array into a single 1D array and returns an unravel function that reconstructs the original object from it:

>>> import numpy as np
>>> vec = vector.array({"x": np.ones(10), "y": np.ones(10)})
>>> flat, unravel = pytree.ravel(vec)
>>> assert flat.shape == (20,)
>>> vec2 = unravel(flat)
>>> assert (vec == vec2).all()
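One way such a ravel could work is sketched below with plain NumPy; this is an assumption about the technique, not vector's implementation. The hypothetical helper ravel_components takes a dict of coordinate arrays (a stand-in for the components of a VectorNumpy object), concatenates them into one 1D array in a fixed order, and returns an unravel closure that splits the array back at the recorded sizes and restores each original shape and dtype.

```python
import numpy as np


def ravel_components(components):
    """Flatten a dict of arrays into one 1D array plus an inverse function.

    Hypothetical helper: ``components`` stands in for the coordinate
    arrays of a VectorNumpy object, e.g. {"x": ..., "y": ...}.
    """
    names = list(components)  # fixed order, used in both directions
    arrays = [np.asarray(components[name]) for name in names]
    shapes = [a.shape for a in arrays]
    dtypes = [a.dtype for a in arrays]
    sizes = [a.size for a in arrays]
    flat = np.concatenate([a.ravel() for a in arrays])

    def unravel(flat_array):
        # Split at the cumulative sizes, then restore each shape and dtype.
        pieces = np.split(np.asarray(flat_array), np.cumsum(sizes)[:-1])
        return {
            name: piece.reshape(shape).astype(dtype)
            for name, piece, shape, dtype in zip(names, pieces, shapes, dtypes)
        }

    return flat, unravel


components = {"x": np.ones(10), "y": np.arange(10.0)}
flat, unravel = ravel_components(components)
print(flat.shape)  # (20,)
restored = unravel(flat)
print(all(np.array_equal(restored[k], components[k]) for k in components))  # True
```

Recording shapes and dtypes at ravel time is what lets unravel take only the flat array, which is convenient when structured data has to pass through code that expects a single 1D array, such as a generic optimizer.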

Note that this function requires the optree package to be installed.