PyTree API
vector.register_pytree() → ReexportedPyTreeModule
Register optree PyTree operations for vector objects.
This module defines how vector objects are handled with the optree package. See https://blog.scientific-python.org/pytrees/ for the rationale for these functions.
After calling this function,
>>> import vector
>>> vector.register_pytree()
<module 'vector.pytree'>
the following classes can be flattened and unflattened with the optree package (here *D stands for 2D, 3D, or 4D):
VectorObject*D
MomentumObject*D
VectorNumpy*D
MomentumNumpy*D
For example:
>>> import optree
>>> vec = vector.obj(x=1, y=2)
>>> leaves, treedef = optree.tree_flatten(vec, namespace="vector")
>>> vec2 = optree.tree_unflatten(treedef, leaves)
>>> assert vec == vec2
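To make the flatten/unflatten round trip concrete, here is a minimal, dependency-free sketch of the idea behind a namespaced PyTree registry. It does not use optree or vector: the `Vec2` class, the `_REGISTRY` dict, and the `register`, `tree_flatten`, and `tree_unflatten` helpers are hypothetical stand-ins, and unlike optree this toy does not treat built-in containers such as lists or dicts as nodes. Flattening splits an object into its leaves (here the coordinates) and a treedef that records how to rebuild it. Registrations are keyed by namespace, so they only take effect when that namespace is requested, which is roughly the role `namespace="vector"` plays above.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical registry: (namespace, type) -> (flatten_fn, unflatten_fn)
_REGISTRY: dict[tuple[str, type], tuple[Callable, Callable]] = {}


def register(cls: type, flatten_fn: Callable, unflatten_fn: Callable, namespace: str) -> None:
    """Register flatten/unflatten functions for cls under a namespace (toy version)."""
    _REGISTRY[(namespace, cls)] = (flatten_fn, unflatten_fn)


def tree_flatten(tree: Any, namespace: str = "") -> tuple[list, Any]:
    """Return (leaves, treedef). Objects not registered in this namespace are leaves."""
    entry = _REGISTRY.get((namespace, type(tree)))
    if entry is None:
        # Not registered in this namespace: the object itself is a single leaf.
        return [tree], None
    flatten_fn, _ = entry
    children, aux = flatten_fn(tree)
    leaves: list = []
    child_defs = []
    for child in children:
        child_leaves, child_def = tree_flatten(child, namespace)
        leaves.extend(child_leaves)
        # Remember how many leaves each child contributed so unflatten can slice them back.
        child_defs.append((len(child_leaves), child_def))
    return leaves, (type(tree), namespace, aux, child_defs)


def tree_unflatten(treedef: Any, leaves: list) -> Any:
    """Rebuild an object from a treedef produced by tree_flatten and its leaves."""
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, namespace, aux, child_defs = treedef
    _, unflatten_fn = _REGISTRY[(namespace, cls)]
    children = []
    start = 0
    for count, child_def in child_defs:
        children.append(tree_unflatten(child_def, leaves[start:start + count]))
        start += count
    return unflatten_fn(aux, children)


@dataclass
class Vec2:
    # Hypothetical stand-in for a 2D vector object.
    x: float
    y: float


register(
    Vec2,
    flatten_fn=lambda v: ([v.x, v.y], None),  # children are the coordinates, no extra metadata
    unflatten_fn=lambda aux, children: Vec2(*children),
    namespace="vector",
)

vec = Vec2(x=1, y=2)
leaves, treedef = tree_flatten(vec, namespace="vector")
vec2 = tree_unflatten(treedef, leaves)
assert leaves == [1, 2] and vec == vec2

# Outside the "vector" namespace the object is opaque: it is a single leaf.
opaque_leaves, _ = tree_flatten(vec)
assert opaque_leaves == [vec]
```

Scoping registrations by namespace means registering a class for one consumer cannot change how unrelated code that flattens the same type sees it.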
As a convenience, we return a re-exported module that can be used without the
namespace argument. For example:

>>> pytree = vector.register_pytree()
>>> vec = vector.obj(x=1, y=2)
>>> leaves, treedef = pytree.flatten(vec)
>>> vec2 = pytree.unflatten(treedef, leaves)
>>> assert vec == vec2
A ravel function is also added to the returned PyTree module. It flattens a VectorNumpy array into a single 1D array and returns that array together with an unravel function that reconstructs the original vector.
>>> import numpy as np
>>> vec = vector.array({"x": np.ones(10), "y": np.ones(10)})
>>> flat, unravel = pytree.ravel(vec)
>>> assert flat.shape == (20,)
>>> vec2 = unravel(flat)
>>> assert (vec == vec2).all()
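The idea behind ravel can be sketched with plain NumPy: concatenate every component array into one 1D array, and return a closure that splits the flat array back into components with their original shapes. This is only an illustration under stated assumptions. It uses a hypothetical `ravel_fields` helper on a dict of arrays rather than vector's own implementation, and it assumes all components share one dtype (`np.concatenate` would upcast mixed dtypes, and this unravel does not restore them).

```python
import numpy as np


def ravel_fields(fields: dict[str, np.ndarray]):
    """Flatten a dict of arrays into one 1D array plus an inverse function (hypothetical helper)."""
    names = list(fields)
    shapes = [np.shape(fields[name]) for name in names]
    sizes = [int(np.prod(shape)) for shape in shapes]
    flat = np.concatenate([np.ravel(fields[name]) for name in names])
    # Split points between consecutive components, e.g. [10] for two size-10 arrays.
    split_points = np.cumsum(sizes)[:-1]

    def unravel(flat_array: np.ndarray) -> dict[str, np.ndarray]:
        # Cut the flat array at the recorded boundaries and restore each component's shape.
        pieces = np.split(flat_array, split_points)
        return {name: piece.reshape(shape) for name, piece, shape in zip(names, pieces, shapes)}

    return flat, unravel


fields = {"x": np.ones(10), "y": np.arange(10.0)}
flat, unravel = ravel_fields(fields)
assert flat.shape == (20,)
restored = unravel(flat)
assert all(np.array_equal(restored[k], fields[k]) for k in fields)
```

Returning a closure keeps the names and shapes out of the caller's hands, which is the same `flat, unravel` pattern the example above uses; the flat array can then be handed to code that only understands 1D inputs, such as a generic optimizer.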
Note that this function requires the optree package to be installed.