Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,222 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=all
import builtins
import enum
import sys
from collections.abc import Callable, Collection, Iterable, Iterator
from types import MappingProxyType
from typing import Any, ClassVar, Final, final
from typing_extensions import Self # Python 3.11+
from optree.typing import (
FlattenFunc,
MetaData,
PyTree,
PyTreeAccessor,
PyTreeEntry,
T,
U,
UnflattenFunc,
)
# Set if the type allows subclassing (see CPython's Include/object.h)
Py_TPFLAGS_BASETYPE: Final[int] # (1UL << 10)
# Meta-information during build-time
BUILDTIME_METADATA: Final[MappingProxyType[str, Any]]
PY_VERSION: Final[str]
PY_VERSION_HEX: Final[int]
if sys.implementation.name == 'pypy': # noqa: PYI002
PYPY_VERSION: Final[str]
PYPY_VERSION_NUM: Final[int]
PYPY_VERSION_HEX: Final[int]
Py_DEBUG: Final[bool]
Py_GIL_DISABLED: Final[bool]
PYBIND11_VERSION_HEX: Final[int]
PYBIND11_INTERNALS_VERSION: Final[int]
PYBIND11_HAS_NATIVE_ENUM: Final[bool]
PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool]
PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
GLIBCXX_USE_CXX11_ABI: Final[bool]
@final
class InternalError(SystemError): ...
@final
class PyTreeKind(enum.IntEnum):
CUSTOM = 0 # a custom type
LEAF = enum.auto() # an opaque leaf node
NONE = enum.auto() # None
TUPLE = enum.auto() # a tuple
LIST = enum.auto() # a list
DICT = enum.auto() # a dict
NAMEDTUPLE = enum.auto() # a collections.namedtuple
ORDEREDDICT = enum.auto() # a collections.OrderedDict
DEFAULTDICT = enum.auto() # a collections.defaultdict
DEQUE = enum.auto() # a collections.deque
STRUCTSEQUENCE = enum.auto() # a PyStructSequence
NUM_KINDS: ClassVar[int]
MAX_RECURSION_DEPTH: Final[int]
@final
class PyTreeSpec:
num_nodes: int
num_leaves: int
num_children: int
none_is_leaf: bool
namespace: str
type: builtins.type | None
kind: PyTreeKind
def unflatten(self, leaves: Iterable[T], /) -> PyTree[T]: ...
def flatten_up_to(self, tree: PyTree[T], /) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: Self, /) -> Self: ...
def transform(
self,
/,
f_node: Callable[[Self], Self] | None = None,
f_leaf: Callable[[Self], Self] | None = None,
) -> Self: ...
def compose(self, inner: Self, /) -> Self: ...
def traverse(
self,
leaves: Iterable[T],
/,
f_node: Callable[[Collection[U]], U] | None = None,
f_leaf: Callable[[T], U] | None = None,
) -> U: ...
def walk(
self,
leaves: Iterable[T],
/,
f_node: Callable[[builtins.type, MetaData, tuple[U, ...]], U] | None = None,
f_leaf: Callable[[T], U] | None = None,
) -> U: ...
def paths(self, /) -> list[tuple[Any, ...]]: ...
def accessors(self, /) -> list[PyTreeAccessor]: ...
def entries(self, /) -> list[Any]: ...
def entry(self, index: int, /) -> Any: ...
def children(self, /) -> list[Self]: ...
def child(self, index: int, /) -> Self: ...
def one_level(self, /) -> Self | None: ...
def is_leaf(self, /, *, strict: bool = True) -> bool: ...
def is_one_level(self, /) -> bool: ...
def is_prefix(self, other: Self, /, *, strict: bool = False) -> bool: ...
def is_suffix(self, other: Self, /, *, strict: bool = False) -> bool: ...
def __eq__(self, other: object, /) -> bool: ...
def __ne__(self, other: object, /) -> bool: ...
def __lt__(self, other: object, /) -> bool: ...
def __le__(self, other: object, /) -> bool: ...
def __gt__(self, other: object, /) -> bool: ...
def __ge__(self, other: object, /) -> bool: ...
def __hash__(self, /) -> int: ...
def __len__(self, /) -> int: ...
@final
class PyTreeIter(Iterator[T]):
def __init__(
self,
tree: PyTree[T],
/,
leaf_predicate: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> None: ...
def __iter__(self, /) -> Self: ...
def __next__(self, /) -> T: ...
# Functions
def flatten(
tree: PyTree[T],
/,
leaf_predicate: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> tuple[list[T], PyTreeSpec]: ...
def flatten_with_path(
tree: PyTree[T],
/,
leaf_predicate: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> tuple[list[tuple[Any, ...]], list[T], PyTreeSpec]: ...
# Constructors
def make_leaf(
none_is_leaf: bool = False,
namespace: str = '', # unused
) -> PyTreeSpec: ...
def make_none(
none_is_leaf: bool = False,
namespace: str = '', # unused
) -> PyTreeSpec: ...
def make_from_collection(
collection: Collection[PyTreeSpec],
/,
none_is_leaf: bool = False,
namespace: str = '',
) -> PyTreeSpec: ...
# Utility functions
def is_leaf(
obj: T,
/,
leaf_predicate: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> bool: ...
def all_leaves(
iterable: Iterable[T],
/,
leaf_predicate: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> bool: ...
def is_namedtuple(obj: object | type, /) -> bool: ...
def is_namedtuple_instance(obj: object, /) -> bool: ...
def is_namedtuple_class(cls: type, /) -> bool: ...
def namedtuple_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: ...
def is_structseq(obj: object | type, /) -> bool: ...
def is_structseq_instance(obj: object, /) -> bool: ...
def is_structseq_class(cls: type, /) -> bool: ...
def structseq_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: ...
# Registration functions
def register_node(
cls: type[Collection[T]],
/,
flatten_func: FlattenFunc[T],
unflatten_func: UnflattenFunc[T],
path_entry_type: type[PyTreeEntry],
namespace: str = '',
) -> None: ...
def unregister_node(
cls: type,
/,
namespace: str = '',
) -> None: ...
def is_dict_insertion_ordered(
namespace: str = '',
inherit_global_namespace: bool = True,
) -> bool: ...
def set_dict_insertion_ordered(
mode: bool,
/,
namespace: str = '',
) -> None: ...
@@ -0,0 +1,237 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""
from optree import accessors, dataclasses, functools, integrations, pytree, treespec, typing
from optree.accessors import (
AutoEntry,
DataclassEntry,
FlattenedEntry,
GetAttrEntry,
GetItemEntry,
MappingEntry,
NamedTupleEntry,
PyTreeAccessor,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
from optree.ops import (
MAX_RECURSION_DEPTH,
NONE_IS_LEAF,
NONE_IS_NODE,
all_leaves,
broadcast_common,
broadcast_prefix,
prefix_errors,
tree_accessors,
tree_all,
tree_any,
tree_broadcast_common,
tree_broadcast_map,
tree_broadcast_map_with_accessor,
tree_broadcast_map_with_path,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_one_level,
tree_flatten_with_accessor,
tree_flatten_with_path,
tree_is_leaf,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
tree_map_with_accessor,
tree_map_with_accessor_,
tree_map_with_path,
tree_map_with_path_,
tree_max,
tree_min,
tree_partition,
tree_paths,
tree_reduce,
tree_replace_nones,
tree_structure,
tree_sum,
tree_transpose,
tree_transpose_map,
tree_transpose_map_with_accessor,
tree_transpose_map_with_path,
tree_unflatten,
treespec_accessors,
treespec_child,
treespec_children,
treespec_defaultdict,
treespec_deque,
treespec_dict,
treespec_entries,
treespec_entry,
treespec_from_collection,
treespec_is_leaf,
treespec_is_one_level,
treespec_is_prefix,
treespec_is_strict_leaf,
treespec_is_suffix,
treespec_leaf,
treespec_list,
treespec_namedtuple,
treespec_none,
treespec_one_level,
treespec_ordereddict,
treespec_paths,
treespec_structseq,
treespec_transform,
treespec_tuple,
)
from optree.registry import (
dict_insertion_ordered,
register_pytree_node,
register_pytree_node_class,
unregister_pytree_node,
)
from optree.typing import (
CustomTreeNode,
FlattenFunc,
PyTree,
PyTreeDef,
PyTreeKind,
PyTreeSpec,
PyTreeTypeVar,
UnflattenFunc,
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)
from optree.version import __version__ as __version__ # pylint: disable=useless-import-alias
__all__ = [
# Tree operations
'MAX_RECURSION_DEPTH',
'NONE_IS_NODE',
'NONE_IS_LEAF',
'tree_flatten',
'tree_flatten_with_path',
'tree_flatten_with_accessor',
'tree_unflatten',
'tree_iter',
'tree_leaves',
'tree_structure',
'tree_paths',
'tree_accessors',
'tree_is_leaf',
'all_leaves',
'tree_map',
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_map_with_accessor',
'tree_map_with_accessor_',
'tree_replace_nones',
'tree_partition',
'tree_transpose',
'tree_transpose_map',
'tree_transpose_map_with_path',
'tree_transpose_map_with_accessor',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_broadcast_common',
'broadcast_common',
'tree_broadcast_map',
'tree_broadcast_map_with_path',
'tree_broadcast_map_with_accessor',
'tree_reduce',
'tree_sum',
'tree_max',
'tree_min',
'tree_all',
'tree_any',
'tree_flatten_one_level',
'prefix_errors',
'treespec_paths',
'treespec_accessors',
'treespec_entries',
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_one_level',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
'treespec_is_one_level',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_leaf',
'treespec_none',
'treespec_tuple',
'treespec_list',
'treespec_dict',
'treespec_namedtuple',
'treespec_ordereddict',
'treespec_defaultdict',
'treespec_deque',
'treespec_structseq',
'treespec_from_collection',
# Accessor
'PyTreeEntry',
'GetAttrEntry',
'GetItemEntry',
'FlattenedEntry',
'AutoEntry',
'SequenceEntry',
'MappingEntry',
'NamedTupleEntry',
'StructSequenceEntry',
'DataclassEntry',
'PyTreeAccessor',
# Registry
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
# Typing
'PyTreeSpec',
'PyTreeDef',
'PyTreeKind',
'PyTree',
'PyTreeTypeVar',
'CustomTreeNode',
'FlattenFunc',
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
'is_namedtuple_instance',
'namedtuple_fields',
'is_structseq',
'is_structseq_class',
'is_structseq_instance',
'structseq_fields',
]
MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal.
This limit prevents infinite recursion from causing an overflow of the C stack
and crashing Python.
"""
NONE_IS_NODE: bool = NONE_IS_NODE # literal constant
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = NONE_IS_LEAF # literal constant
"""Literal constant that treats :data:`None` as a pytree leaf node."""
@@ -0,0 +1,443 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Access support for pytrees."""
from __future__ import annotations
import dataclasses
import sys
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload
from typing_extensions import Self # Python 3.11+
import optree._C as _C
from optree._C import PyTreeKind
if TYPE_CHECKING:
import builtins
from optree.typing import NamedTuple, StructSequence
__all__ = [
'PyTreeEntry',
'GetItemEntry',
'GetAttrEntry',
'FlattenedEntry',
'AutoEntry',
'SequenceEntry',
'MappingEntry',
'NamedTupleEntry',
'StructSequenceEntry',
'DataclassEntry',
'PyTreeAccessor',
]
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+
@dataclasses.dataclass(init=True, repr=False, eq=False, frozen=True, **SLOTS)
class PyTreeEntry:
"""Base class for path entries."""
entry: Any
type: builtins.type
kind: PyTreeKind
def __post_init__(self, /) -> None:
"""Post-initialize the path entry."""
if self.kind == PyTreeKind.LEAF:
raise ValueError('Cannot create a leaf path entry.')
if self.kind == PyTreeKind.NONE:
raise ValueError('Cannot create a path entry for None.')
def __call__(self, obj: Any, /) -> Any:
"""Get the child object."""
try:
return obj[self.entry] # should be overridden
except TypeError as ex:
raise TypeError(
f'{self.__class__!r} cannot access through {obj!r} via entry {self.entry!r}',
) from ex
def __add__(self, other: object, /) -> PyTreeAccessor:
"""Join the path entry with another path entry or accessor."""
if isinstance(other, PyTreeEntry):
return PyTreeAccessor((self, other))
if isinstance(other, PyTreeAccessor):
return PyTreeAccessor((self, *other))
return NotImplemented
def __eq__(self, other: object, /) -> bool:
"""Check if the path entries are equal."""
return isinstance(other, PyTreeEntry) and (
(
self.entry,
self.type,
self.kind,
self.__class__.__call__.__code__.co_code,
self.__class__.codify.__code__.co_code,
)
== (
other.entry,
other.type,
other.kind,
other.__class__.__call__.__code__.co_code,
other.__class__.codify.__code__.co_code,
)
)
def __hash__(self, /) -> int:
"""Get the hash of the path entry."""
return hash(
(
self.entry,
self.type,
self.kind,
self.__class__.__call__.__code__.co_code,
self.__class__.codify.__code__.co_code,
),
)
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(entry={self.entry!r}, type={self.type!r})'
def codify(self, /, node: str = '') -> str:
"""Generate code for accessing the path entry."""
return f'{node}[<flat index {self.entry!r}>]' # should be overridden
del SLOTS
_T = TypeVar('_T')
_T_co = TypeVar('_T_co', covariant=True)
_KT_co = TypeVar('_KT_co', covariant=True)
_VT_co = TypeVar('_VT_co', covariant=True)
class AutoEntry(PyTreeEntry):
"""A generic path entry class that determines the entry type on creation automatically."""
__slots__: ClassVar[tuple[()]] = ()
def __new__( # type: ignore[misc]
cls,
/,
entry: Any,
type: builtins.type, # pylint: disable=redefined-builtin
kind: PyTreeKind,
) -> PyTreeEntry:
"""Create a new path entry."""
# pylint: disable-next=import-outside-toplevel
from optree.typing import is_namedtuple_class, is_structseq_class
if cls is not AutoEntry:
# Use the subclass type if the type is explicitly specified
return super().__new__(cls)
if kind != PyTreeKind.CUSTOM:
raise ValueError(f'Cannot create an automatic path entry for PyTreeKind {kind!r}.')
# Dispatch the path entry type based on the node type
path_entry_type: builtins.type[PyTreeEntry]
if is_structseq_class(type):
path_entry_type = StructSequenceEntry
elif is_namedtuple_class(type):
path_entry_type = NamedTupleEntry
elif dataclasses.is_dataclass(type):
path_entry_type = DataclassEntry
elif issubclass(type, Mapping):
path_entry_type = MappingEntry
elif issubclass(type, Sequence):
path_entry_type = SequenceEntry
else:
path_entry_type = FlattenedEntry
if not issubclass(path_entry_type, AutoEntry):
# The __init__() method will not be called if the returned instance is not a subtype of
# AutoEntry. We should return an initialized instance. Return a fully-initialized
# instance of the dispatched type.
return path_entry_type(entry, type, kind)
# The __init__() method will be called if the returned instance is a subtype of AutoEntry.
# We should return an uninitialized instance. The __init__() method will initialize it.
# But we will never reach here because the dispatched type is never a subtype of AutoEntry.
raise NotImplementedError('Unreachable code.')
class GetItemEntry(PyTreeEntry):
"""A generic path entry class for nodes that access their children by :meth:`__getitem__`."""
__slots__: ClassVar[tuple[()]] = ()
def __call__(self, obj: Any, /) -> Any:
"""Get the child object."""
return obj[self.entry]
def codify(self, /, node: str = '') -> str:
"""Generate code for accessing the path entry."""
return f'{node}[{self.entry!r}]'
class GetAttrEntry(PyTreeEntry):
"""A generic path entry class for nodes that access their children by :meth:`__getattr__`."""
__slots__: ClassVar[tuple[()]] = ()
entry: str
@property
def name(self, /) -> str:
"""Get the attribute name."""
return self.entry
def __call__(self, obj: Any, /) -> Any:
"""Get the child object."""
return getattr(obj, self.name)
def codify(self, /, node: str = '') -> str:
"""Generate code for accessing the path entry."""
return f'{node}.{self.name}'
class FlattenedEntry(PyTreeEntry): # pylint: disable=too-few-public-methods
"""A fallback path entry class for flattened objects."""
__slots__: ClassVar[tuple[()]] = ()
class SequenceEntry(GetItemEntry, Generic[_T_co]):
"""A path entry class for sequences."""
__slots__: ClassVar[tuple[()]] = ()
entry: int
type: builtins.type[Sequence[_T_co]]
@property
def index(self, /) -> int:
"""Get the index."""
return self.entry
def __call__(self, obj: Sequence[_T_co], /) -> _T_co:
"""Get the child object."""
return obj[self.index]
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(index={self.index!r}, type={self.type!r})'
class MappingEntry(GetItemEntry, Generic[_KT_co, _VT_co]):
"""A path entry class for mappings."""
__slots__: ClassVar[tuple[()]] = ()
entry: _KT_co
type: builtins.type[Mapping[_KT_co, _VT_co]]
@property
def key(self, /) -> _KT_co:
"""Get the key."""
return self.entry
def __call__(self, obj: Mapping[_KT_co, _VT_co], /) -> _VT_co:
"""Get the child object."""
return obj[self.key]
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(key={self.key!r}, type={self.type!r})'
class NamedTupleEntry(SequenceEntry[_T]):
"""A path entry class for namedtuple objects."""
__slots__: ClassVar[tuple[()]] = ()
entry: int
type: builtins.type[NamedTuple[_T]] # type: ignore[type-arg]
kind: Literal[PyTreeKind.NAMEDTUPLE]
@property
def fields(self, /) -> tuple[str, ...]:
"""Get the field names."""
from optree.typing import namedtuple_fields # pylint: disable=import-outside-toplevel
return namedtuple_fields(self.type)
@property
def field(self, /) -> str:
"""Get the field name."""
return self.fields[self.entry]
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
def codify(self, /, node: str = '') -> str:
"""Generate code for accessing the path entry."""
return f'{node}.{self.field}'
class StructSequenceEntry(SequenceEntry[_T]):
"""A path entry class for PyStructSequence objects."""
__slots__: ClassVar[tuple[()]] = ()
entry: int
type: builtins.type[StructSequence[_T]]
kind: Literal[PyTreeKind.STRUCTSEQUENCE]
@property
def fields(self, /) -> tuple[str, ...]:
"""Get the field names."""
from optree.typing import structseq_fields # pylint: disable=import-outside-toplevel
return structseq_fields(self.type)
@property
def field(self, /) -> str:
"""Get the field name."""
return self.fields[self.entry]
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
def codify(self, /, node: str = '') -> str:
"""Generate code for accessing the path entry."""
return f'{node}.{self.field}'
class DataclassEntry(GetAttrEntry):
"""A path entry class for dataclasses."""
__slots__: ClassVar[tuple[()]] = ()
entry: str | int # type: ignore[assignment]
@property
def fields(self, /) -> tuple[str, ...]: # pragma: no cover
"""Get all field names."""
return tuple(f.name for f in dataclasses.fields(self.type))
@property
def init_fields(self, /) -> tuple[str, ...]:
"""Get the init field names."""
return tuple(f.name for f in dataclasses.fields(self.type) if f.init)
@property
def field(self, /) -> str:
"""Get the field name."""
if isinstance(self.entry, int):
return self.init_fields[self.entry]
return self.entry
@property
def name(self, /) -> str:
"""Get the attribute name."""
return self.field
def __repr__(self, /) -> str:
"""Get the representation of the path entry."""
return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
class PyTreeAccessor(tuple[PyTreeEntry, ...]):
"""A path class for PyTrees."""
__slots__: ClassVar[tuple[()]] = ()
@property
def path(self, /) -> tuple[Any, ...]:
"""Get the path of the accessor."""
return tuple(e.entry for e in self)
def __new__(cls, /, path: Iterable[PyTreeEntry] = ()) -> Self:
"""Create a new accessor instance."""
if not isinstance(path, (list, tuple)):
path = tuple(path)
if not all(isinstance(p, PyTreeEntry) for p in path):
raise TypeError(f'Expected a path of PyTreeEntry, got {path!r}.')
return super().__new__(cls, path)
def __call__(self, obj: Any, /) -> Any:
"""Get the child object."""
for entry in self:
obj = entry(obj)
return obj
@overload # type: ignore[override]
def __getitem__(self, index: int, /) -> PyTreeEntry: ...
@overload
def __getitem__(self, index: slice, /) -> Self: ...
def __getitem__(self, index: int | slice, /) -> PyTreeEntry | Self:
"""Get the child path entry or an accessor for a subpath."""
if isinstance(index, slice):
return self.__class__(super().__getitem__(index))
return super().__getitem__(index)
def __add__(self, other: object, /) -> Self:
"""Join the accessor with another path entry or accessor."""
if isinstance(other, PyTreeEntry):
return self.__class__((*self, other))
if isinstance(other, PyTreeAccessor):
return self.__class__((*self, *other))
return NotImplemented
def __mul__(self, value: int, /) -> Self: # type: ignore[override]
"""Repeat the accessor."""
return self.__class__(super().__mul__(value))
def __rmul__(self, value: int, /) -> Self: # type: ignore[override]
"""Repeat the accessor."""
return self.__class__(super().__rmul__(value))
def __eq__(self, other: object, /) -> bool:
"""Check if the accessors are equal."""
return isinstance(other, PyTreeAccessor) and super().__eq__(other)
def __hash__(self, /) -> int:
"""Get the hash of the accessor."""
return super().__hash__()
def __repr__(self, /) -> str:
"""Get the representation of the accessor."""
return f'{self.__class__.__name__}({self.codify()}, {super().__repr__()})'
def codify(self, /, root: str = '*') -> str:
"""Generate code for accessing the path."""
string = root
for entry in self:
string = entry.codify(string)
return string
# These classes are used internally in the C++ side for accessor APIs
_name, _cls = '', object
for _name in __all__:
_cls = globals()[_name]
if not isinstance(_cls, type): # pragma: no cover
raise TypeError(f'Expected a class, got {_cls!r}.')
_cls.__module__ = 'optree'
setattr(_C, _name, _cls)
del _name, _cls
@@ -0,0 +1,509 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PyTree integration with :mod:`dataclasses`.
This module implements PyTree integration with :mod:`dataclasses` by redefining the :func:`field`,
:func:`dataclass`, and :func:`make_dataclass` functions. Other APIs are re-exported from the
original :mod:`dataclasses` module.
The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields
are stored in a special attribute named ``__optree_dataclass_fields__`` in the dataclass.
>>> import math
... import optree
...
>>> @optree.dataclasses.dataclass(namespace='my_module')
... class Point:
... x: float
... y: float
... z: float = 0.0
... norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
... def __post_init__(self) -> None:
... self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point) # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS
(
[
PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
],
[2.0, 6.0, 3.0],
PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
"""
# pylint: disable=too-many-arguments
from __future__ import annotations
import contextlib
import dataclasses
import functools
import inspect
import sys
import types
from dataclasses import * # noqa: F401,F403,RUF100 # pylint: disable=wildcard-import,unused-wildcard-import
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, TypeVar, overload
from typing_extensions import dataclass_transform # Python 3.11+
if TYPE_CHECKING:
from collections.abc import Iterable
# Redefine `field`, `dataclasses`, and `make_dataclasses`.
# The remaining APIs are re-exported from the original package.
__all__ = [*dataclasses.__all__]
_FIELDS = '__optree_dataclass_fields__'
_PYTREE_NODE_DEFAULT: bool = True
_T = TypeVar('_T')
_U = TypeVar('_U')
_TypeT = TypeVar('_TypeT', bound=type)
@overload # type: ignore[no-redef]
def field(
*,
default: _T,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
hash: bool | None = None, # pylint: disable=redefined-builtin
compare: bool = True,
metadata: dict[Any, Any] | None = None,
kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+
doc: str | None = None, # Python 3.14+
pytree_node: bool | None = None,
) -> _T: ...
@overload
def field(
*,
default_factory: Callable[[], _T],
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
hash: bool | None = None, # pylint: disable=redefined-builtin
compare: bool = True,
metadata: dict[Any, Any] | None = None,
kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+
doc: str | None = None, # Python 3.14+
pytree_node: bool | None = None,
) -> _T: ...
@overload
def field(
*,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
hash: bool | None = None, # pylint: disable=redefined-builtin
compare: bool = True,
metadata: dict[Any, Any] | None = None,
kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+
doc: str | None = None, # Python 3.14+
pytree_node: bool | None = None,
) -> Any: ...
def field( # noqa: D417 # pylint: disable=function-redefined
*,
default: Any = dataclasses.MISSING,
default_factory: Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
hash: bool | None = None, # pylint: disable=redefined-builtin
compare: bool = True,
metadata: dict[Any, Any] | None = None,
kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+
doc: str | None = None, # Python 3.14+
pytree_node: bool | None = None,
) -> Any:
"""Field factory for :func:`dataclass`.
This factory function is used to define the fields in a dataclass. It is similar to the field
factory :func:`dataclasses.field`, but with an additional ``pytree_node`` parameter. If
``pytree_node`` is :data:`True` (default), the field will be considered a child node in the
PyTree structure which can be recursively flattened and unflattened. Otherwise, the field will
be considered as PyTree metadata.
Setting ``pytree_node`` in the field factory is equivalent to setting a key ``'pytree_node'`` in
``metadata`` in the original field factory. The ``pytree_node`` value can be accessed using
``field.metadata['pytree_node']``. If ``pytree_node`` is :data:`None`, the value
``metadata.get('pytree_node', True)`` will be used.
.. note::
If a field is considered a child node, it must be included in the argument list of the
:meth:`__init__` method, i.e., passes ``init=True`` in the field factory.
Args:
pytree_node (bool or None, optional): Whether the field is a PyTree node.
**kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.field`.
Returns:
dataclasses.Field: The field defined using the provided arguments with
``field.metadata['pytree_node']`` set.
"""
metadata = (metadata or {}).copy()
if pytree_node is None:
pytree_node = metadata.get('pytree_node', _PYTREE_NODE_DEFAULT)
metadata['pytree_node'] = pytree_node
kwargs = {
'default': default,
'default_factory': default_factory,
'init': init,
'repr': repr,
'hash': hash,
'compare': compare,
'metadata': metadata,
}
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
kwargs['kw_only'] = kw_only
elif kw_only is not dataclasses.MISSING: # pragma: <3.10 cover
raise TypeError("field() got an unexpected keyword argument 'kw_only'")
if sys.version_info >= (3, 14): # pragma: >=3.14 cover
kwargs['doc'] = doc
elif doc is not None: # pragma: <3.14 cover
raise TypeError("field() got an unexpected keyword argument 'doc'")
if not init and pytree_node:
raise TypeError(
'`pytree_node=True` is not allowed for non-init fields. '
f'Please explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
)
return dataclasses.field(**kwargs) # pylint: disable=invalid-field-call
@overload # type: ignore[no-redef]
def dataclass(
*,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
namespace: str,
) -> Callable[[_TypeT], _TypeT]: ...
@overload
def dataclass(
cls: _TypeT,
/,
*,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
namespace: str,
) -> _TypeT: ...
@dataclass_transform(field_specifiers=(field,))
def dataclass( # noqa: C901,D417 # pylint: disable=function-redefined,too-many-locals,too-many-branches
cls: _TypeT | None = None,
/,
*,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
namespace: str,
) -> _TypeT | Callable[[_TypeT], _TypeT]:
"""Dataclass decorator with PyTree integration.
Args:
cls (type or None, optional): The class to decorate. If :data:`None`, return a decorator.
namespace (str): The registry namespace used for the PyTree registration.
**kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.dataclass`.
Returns:
type or callable: The decorated class with PyTree integration or decorator function.
"""
# pylint: disable-next=import-outside-toplevel
from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE
kwargs = {
'init': init,
'repr': repr,
'eq': eq,
'order': order,
'unsafe_hash': unsafe_hash,
'frozen': frozen,
}
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
kwargs['match_args'] = match_args
kwargs['kw_only'] = kw_only
kwargs['slots'] = slots
elif match_args is not True: # pragma: <3.10 cover
raise TypeError("dataclass() got an unexpected keyword argument 'match_args'")
elif kw_only is not False: # pragma: <3.10 cover
raise TypeError("dataclass() got an unexpected keyword argument 'kw_only'")
elif slots is not False: # pragma: <3.10 cover
raise TypeError("dataclass() got an unexpected keyword argument 'slots'")
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
kwargs['weakref_slot'] = weakref_slot
elif weakref_slot is not False: # pragma: <3.11 cover
raise TypeError("dataclass() got an unexpected keyword argument 'weakref_slot'")
if cls is None:
def decorator(cls: _TypeT) -> _TypeT:
return dataclass(cls, namespace=namespace, **kwargs) # type: ignore[call-overload]
return decorator
if not inspect.isclass(cls):
raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.')
if _FIELDS in cls.__dict__:
raise TypeError(
f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.',
)
if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
namespace = GLOBAL_NAMESPACE
cls = dataclasses.dataclass(cls, **kwargs) # type: ignore[assignment]
children_fields = {}
metadata_fields = {}
for f in dataclasses.fields(cls):
if f.metadata.get('pytree_node', _PYTREE_NODE_DEFAULT):
if not f.init:
raise TypeError(
f'PyTree node field {f.name!r} must be included in `__init__()`. '
f'Or you can explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
)
children_fields[f.name] = f
elif f.init:
metadata_fields[f.name] = f
children_field_names = tuple(children_fields)
children_fields = types.MappingProxyType(children_fields)
metadata_fields = types.MappingProxyType(metadata_fields)
setattr(cls, _FIELDS, (children_fields, metadata_fields))
def flatten_func(
obj: _T,
/,
) -> tuple[
tuple[_U, ...],
tuple[tuple[str, Any], ...],
tuple[str, ...],
]:
children = tuple(getattr(obj, name) for name in children_field_names)
metadata = tuple((name, getattr(obj, name)) for name in metadata_fields)
return children, metadata, children_field_names
# pylint: disable-next=line-too-long
def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[_U, ...], /) -> _T: # type: ignore[type-var]
kwargs = dict(zip(children_field_names, children))
kwargs.update(metadata)
return cls(**kwargs)
from optree.accessors import DataclassEntry # pylint: disable=import-outside-toplevel
from optree.registry import register_pytree_node # pylint: disable=import-outside-toplevel
return register_pytree_node( # type: ignore[return-value]
cls,
flatten_func,
unflatten_func, # type: ignore[arg-type]
path_entry_type=DataclassEntry,
namespace=namespace,
)
class _DataclassDecorator(Protocol[_TypeT]): # pylint: disable=too-few-public-methods
def __call__( # pylint: disable=arguments-differ
self,
cls: _TypeT,
/,
*,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True,
kw_only: bool = False,
slots: bool = False,
weakref_slot: bool = False,
) -> _TypeT:
raise NotImplementedError
# pylint: disable-next=function-redefined,too-many-locals,too-many-branches
def make_dataclass( # type: ignore[no-redef] # noqa: C901,D417
cls_name: str,
# pylint: disable-next=redefined-outer-name
fields: Iterable[str | tuple[str, Any] | tuple[str, Any, Any]],
*,
bases: tuple[type, ...] = (),
ns: dict[str, Any] | None = None, # redirect to `namespace` to `dataclasses.make_dataclass()`
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
module: str | None = None, # Python 3.12+
decorator: _DataclassDecorator[_TypeT] = dataclasses.dataclass, # type: ignore[assignment] # Python 3.14+
namespace: str, # the PyTree registration namespace
) -> _TypeT:
"""Make a new dynamically created dataclass with PyTree integration.
The dataclass name will be ``cls_name``. ``fields`` is an iterable of either (name), (name, type),
or (name, type, Field) objects. If type is omitted, use the string :data:`typing.Any`. Field
objects are created by the equivalent of calling :func:`field` (name, type [, Field-info]).
The ``namespace`` parameter is the PyTree registration namespace which should be a string. The
``namespace`` in the original :func:`dataclasses.make_dataclass` function is renamed to ``ns``
to avoid conflicts.
The remaining parameters are passed to :func:`dataclasses.make_dataclass`.
See :func:`dataclasses.make_dataclass` for more information.
Args:
cls_name: The name of the dataclass.
fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]): An iterable of either
(name), (name, type), or (name, type, Field) objects.
namespace (str): The registry namespace used for the PyTree registration.
ns (dict or None, optional): The namespace used in dynamic type creation.
See :func:`dataclasses.make_dataclass` and the builtin :func:`type` function for more
information.
**kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.make_dataclass`.
Returns:
type: The dynamically created dataclass with PyTree integration.
"""
# pylint: disable-next=import-outside-toplevel
from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE
if isinstance(namespace, dict) or namespace is None: # type: ignore[unreachable]
if ns is GLOBAL_NAMESPACE or isinstance(ns, str): # type: ignore[unreachable]
ns, namespace = namespace, ns
elif ns is None:
raise TypeError("make_dataclass() missing 1 required keyword-only argument: 'ns'")
if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
namespace = GLOBAL_NAMESPACE
dataclass_kwargs = {
'init': init,
'repr': repr,
'eq': eq,
'order': order,
'unsafe_hash': unsafe_hash,
'frozen': frozen,
}
make_dataclass_kwargs = {
'bases': bases,
'namespace': ns,
}
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
dataclass_kwargs['match_args'] = match_args
dataclass_kwargs['kw_only'] = kw_only
dataclass_kwargs['slots'] = slots
elif match_args is not True: # pragma: <3.10 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'match_args'")
elif kw_only is not False: # pragma: <3.10 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'kw_only'")
elif slots is not False: # pragma: <3.10 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'slots'")
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
dataclass_kwargs['weakref_slot'] = weakref_slot
elif weakref_slot is not False: # pragma: <3.11 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'weakref_slot'")
if sys.version_info >= (3, 12): # pragma: >=3.12 cover
if module is None:
try:
# pylint: disable-next=protected-access
module = sys._getframemodulename(1) or '__main__' # type: ignore[attr-defined]
except AttributeError: # pragma: no cover
with contextlib.suppress(AttributeError, ValueError):
# pylint: disable-next=protected-access
module = sys._getframe(1).f_globals.get('__name__', '__main__')
make_dataclass_kwargs['module'] = module
elif module is not None: # pragma: <3.12 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'module'")
registered_by_decorator = False
if sys.version_info >= (3, 14): # pragma: >=3.14 cover
if decorator in (dataclasses.dataclass, dataclass):
decorator = functools.partial(dataclass, namespace=namespace)
registered_by_decorator = True
make_dataclass_kwargs['decorator'] = decorator
elif decorator is not dataclasses.dataclass: # pragma: <3.14 cover
raise TypeError("make_dataclass() got an unexpected keyword argument 'decorator'")
cls: _TypeT = dataclasses.make_dataclass( # type: ignore[assignment]
cls_name,
fields=fields,
**dataclass_kwargs, # type: ignore[arg-type]
**make_dataclass_kwargs, # type: ignore[arg-type]
)
if not registered_by_decorator: # pragma: <3.14 cover
dataclass_kwargs.pop('slots', None) # already defined in `make_dataclass()`
dataclass_kwargs.pop('weakref_slot', None) # already used in `make_dataclass()`
cls = dataclass(cls, **dataclass_kwargs, namespace=namespace) # type: ignore[call-overload]
return cls
@@ -0,0 +1,169 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PyTree integration with :mod:`functools`."""
from __future__ import annotations
import functools
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from typing_extensions import Self # Python 3.11+
from optree import registry
from optree.accessors import GetAttrEntry
from optree.ops import tree_reduce as reduce
from optree.typing import CustomTreeNode, T
if TYPE_CHECKING:
from optree.accessors import PyTreeEntry
__all__ = [
'partial',
'reduce',
]
class _HashablePartialShim:
"""Object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to another object."""
__slots__: ClassVar[tuple[str, ...]] = ('args', 'func', 'keywords', 'partial_func')
func: Callable[..., Any]
args: tuple[Any, ...]
keywords: dict[str, Any]
def __init__(self, partial_func: functools.partial, /) -> None:
self.partial_func: functools.partial = partial_func
def __call__(self, /, *args: Any, **kwargs: Any) -> Any:
return self.partial_func(*args, **kwargs)
def __eq__(self, other: object, /) -> bool:
if isinstance(other, _HashablePartialShim):
return self.partial_func == other.partial_func
return self.partial_func == other
def __hash__(self, /) -> int:
return hash(self.partial_func)
def __repr__(self, /) -> str:
return repr(self.partial_func)
# pylint: disable-next=protected-access
@registry.register_pytree_node_class(namespace=registry.__GLOBAL_NAMESPACE)
class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-methods
functools.partial,
CustomTreeNode[T],
):
"""A version of :func:`functools.partial` that works in pytrees.
Use it for partial function evaluation in a way that is compatible with transformations,
e.g., ``partial(func, *args, **kwargs)``.
(You need to explicitly opt-in to this behavior because we did not want to give
:func:`functools.partial` different semantics than normal function closures.)
For example, here is a basic usage of :class:`partial` in a manner similar to
:func:`functools.partial`:
>>> import operator
>>> import torch
>>> add_one = partial(operator.add, torch.ones(()))
>>> add_one(torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
[4., 5.]])
Pytree compatibility means that the resulting partial function can be passed as an argument
within tree-map functions, which is not possible with a standard :func:`functools.partial`
function:
>>> def call_func_on_cuda(f, *args, **kwargs):
... f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs))
... return f(*args, **kwargs)
...
>>> # doctest: +SKIP
>>> tree_map(lambda t: t.cuda(), add_one)
optree.functools.partial(<built-in function add>, tensor(1., device='cuda:0'))
>>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
[4., 5.]], device='cuda:0')
Passing zero arguments to :class:`partial` effectively wraps the original function, making it a
valid argument in tree-map functions:
>>> # doctest: +SKIP
>>> call_func_on_cuda(partial(torch.add), torch.tensor(1), torch.tensor(2))
tensor(3, device='cuda:0')
Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
a :class:`TypeError` or :class:`AttributeError`.
"""
__slots__: ClassVar[tuple[()]] = ()
func: Callable[..., Any]
args: tuple[T, ...]
keywords: dict[str, T]
TREE_PATH_ENTRY_TYPE: ClassVar[type[PyTreeEntry]] = GetAttrEntry
def __new__(cls, func: Callable[..., Any], /, *args: T, **keywords: T) -> Self:
"""Create a new :class:`partial` instance."""
# In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__
# would merge the arguments of this partial instance with the arguments of the func. We box
# func in a class that does not (yet) have a `func` attribute to defeat this optimization,
# since we care exactly which arguments are considered part of the pytree.
if isinstance(func, functools.partial):
original_func = func
func = _HashablePartialShim(original_func)
assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute'
out = super().__new__(cls, func, *args, **keywords)
func.func = original_func.func
func.args = original_func.args
func.keywords = original_func.keywords
return out
return super().__new__(cls, func, *args, **keywords)
def __repr__(self, /) -> str:
"""Return a string representation of the :class:`partial` instance."""
args = [repr(self.func)]
args.extend(repr(x) for x in self.args)
args.extend(f'{k}={v!r}' for (k, v) in self.keywords.items())
return f'{self.__class__.__module__}.{self.__class__.__qualname__}({", ".join(args)})'
def __tree_flatten__( # type: ignore[override]
self,
/,
) -> tuple[
tuple[tuple[T, ...], dict[str, T]],
Callable[..., Any],
tuple[str, str],
]:
"""Flatten the :class:`partial` instance to children and metadata."""
return (self.args, self.keywords), self.func, ('args', 'keywords')
@classmethod
def __tree_unflatten__( # type: ignore[override]
cls,
metadata: Callable[..., Any],
children: tuple[tuple[T, ...], dict[str, T]],
/,
) -> Self:
"""Unflatten the children and metadata into a :class:`partial` instance."""
args, keywords = children
return cls(metadata, *args, **keywords)
@@ -0,0 +1,49 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integrations with third-party libraries."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import ModuleType
from optree.integrations import jax, numpy, torch
SUBMODULES: frozenset[str] = frozenset({'jax', 'numpy', 'torch'})
def __dir__() -> list[str]:
return [*sorted(SUBMODULES), 'SUBMODULES']
def __getattr__(name: str, /) -> ModuleType:
if name in SUBMODULES:
import importlib # pylint: disable=import-outside-toplevel
import sys # pylint: disable=import-outside-toplevel
module = sys.modules[__name__]
submodule = importlib.import_module(f'{__name__}.{name}') # pragma: no cover
setattr(module, name, submodule) # pragma: no cover
return submodule # pragma: no cover
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
del TYPE_CHECKING
@@ -0,0 +1,293 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/google/jax/blob/jax-v0.4.20/jax/_src/flatten_util.py
# ==============================================================================
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration with JAX."""
# pragma: jax cover file
# pylint: disable=import-error
from __future__ import annotations
import contextlib
import itertools
import warnings
from operator import itemgetter
from types import FunctionType
from typing import Any, Callable
from typing_extensions import TypeAlias # Python 3.10+
import jax.numpy as jnp
from jax import Array, lax
from jax._src import dtypes
from jax.typing import ArrayLike
from optree.ops import tree_flatten, tree_unflatten
from optree.typing import PyTreeSpec, PyTreeTypeVar
from optree.utils import safe_zip, total_order_sorted
__all__ = ['ArrayLikeTree', 'ArrayTree', 'tree_ravel']
# pylint: disable-next=invalid-name
ArrayLikeTree: TypeAlias = PyTreeTypeVar('ArrayLikeTree', ArrayLike) # type: ignore[valid-type]
# pylint: disable-next=invalid-name
ArrayTree: TypeAlias = PyTreeTypeVar('ArrayTree', Array) # type: ignore[valid-type]
# Vendor from https://github.com/google/jax/blob/jax-v0.4.20/jax/_src/util.py
class HashablePartial: # pragma: no cover
"""A hashable version of :class:`functools.partial`."""
func: FunctionType
args: tuple[Any, ...]
kwargs: dict[str, Any]
def __init__(self, func: FunctionType | HashablePartial, /, *args: Any, **kwargs: Any) -> None:
"""Construct a :class:`HashablePartial` instance."""
if not callable(func):
raise TypeError(f'Expected a callable, got {func!r}.')
if isinstance(func, HashablePartial):
self.func = func.func
self.args = func.args + args
self.kwargs = {**func.kwargs, **kwargs}
elif isinstance(func, FunctionType):
self.func = func # type: ignore[assignment]
self.args = args
self.kwargs = kwargs
else:
raise TypeError(f'Expected a function, got {func!r}.')
def __eq__(self, other: object, /) -> bool:
return (
type(other) is HashablePartial # pylint: disable=unidiomatic-typecheck
and self.func.__code__ == other.func.__code__
and (self.args, self.kwargs) == (other.args, other.kwargs)
)
def __hash__(self, /) -> int:
return hash(
(
self.func.__code__,
self.args,
tuple(total_order_sorted(self.kwargs.items(), key=itemgetter(0))),
),
)
def __call__(self, /, *args: Any, **kwargs: Any) -> Any:
kwargs = {**self.kwargs, **kwargs}
return self.func(*self.args, *args, **kwargs)
with contextlib.suppress(ImportError): # pragma: no cover
# pylint: disable-next=ungrouped-imports
from jax._src.util import HashablePartial # type: ignore[no-redef] # noqa: F811,RUF100
def tree_ravel(
tree: ArrayLikeTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
) -> tuple[Array, Callable[[Array], ArrayTree]]:
r"""Ravel (flatten) a pytree of arrays down to a 1D array.
>>> tree = {
... 'layer1': {
... 'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
... 'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
... },
... 'layer2': {
... 'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
... 'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,)),
... },
... }
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A pair ``(array, unravel_func)`` where the first element is a 1D array representing the
flattened and concatenated leaf values, with ``dtype`` determined by promoting the
``dtype``\s of leaf values, and the second element is a callable for unflattening a 1D array
of the same length back to a pytree of the same structure as the input ``tree``. If the
input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of the
default dtype is returned in the first component of the output.
"""
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat, unravel_flat = _ravel_leaves(leaves)
return flat, HashablePartial(_tree_unravel, treespec, unravel_flat) # type: ignore[arg-type]
ravel_pytree = tree_ravel
def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[Array], list[ArrayLike]],
flat: Array,
/,
) -> ArrayTree:
return tree_unflatten(treespec, unravel_flat(flat))
def _ravel_leaves(
leaves: list[ArrayLike],
/,
) -> tuple[
Array,
Callable[[Array], list[ArrayLike]],
]:
if not leaves:
return (jnp.zeros(0), _unravel_empty)
from_dtypes = tuple(dtypes.dtype(leaf) for leaf in leaves)
to_dtype = dtypes.result_type(*from_dtypes)
sizes = tuple(jnp.size(leaf) for leaf in leaves)
shapes = tuple(jnp.shape(leaf) for leaf in leaves)
indices = tuple(itertools.accumulate(sizes))
if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
# See https://github.com/google/jax/issues/7809.
raveled = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
return (
raveled,
HashablePartial(_unravel_leaves_single_dtype, indices, shapes), # type: ignore[arg-type]
)
# When there is more than one distinct input dtype, we perform type conversions and produce a
# dtype-specific unravel function.
raveled = jnp.concatenate(
[jnp.ravel(lax.convert_element_type(leaf, to_dtype)) for leaf in leaves],
)
return (
raveled,
HashablePartial(_unravel_leaves, indices, shapes, from_dtypes, to_dtype), # type: ignore[arg-type]
)
def _unravel_empty(flat: Array, /) -> list[ArrayLike]:
if jnp.shape(flat) != (0,):
raise ValueError(
f'The unravel function expected an array of shape {(0,)}, got shape {jnp.shape(flat)}.',
)
return []
def _unravel_leaves_single_dtype(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: Array,
/,
) -> list[Array]:
if jnp.shape(flat) != (indices[-1],):
raise ValueError(
f'The unravel function expected an array of shape {(indices[-1],)}, '
f'got shape {jnp.shape(flat)}.',
)
chunks = jnp.split(flat, indices[:-1])
return [chunk.reshape(shape) for chunk, shape in safe_zip(chunks, shapes)]
def _unravel_leaves(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
from_dtypes: tuple[jnp.dtype, ...],
to_dtype: jnp.dtype,
flat: Array,
/,
) -> list[Array]:
if jnp.shape(flat) != (indices[-1],):
raise ValueError(
f'The unravel function expected an array of shape {(indices[-1],)}, '
f'got shape {jnp.shape(flat)}.',
)
array_dtype = dtypes.dtype(flat)
if array_dtype != to_dtype:
raise ValueError(
f'The unravel function expected an array of dtype {to_dtype}, got dtype {array_dtype}.',
)
chunks = jnp.split(flat, indices[:-1])
with warnings.catch_warnings():
warnings.simplefilter('ignore') # ignore complex-to-real cast warning
return [
lax.convert_element_type(chunk.reshape(shape), dtype)
for chunk, shape, dtype in safe_zip(chunks, shapes, from_dtypes)
]
@@ -0,0 +1,218 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration with NumPy."""
# pragma: numpy cover file
# pylint: disable=import-error
from __future__ import annotations
import functools
import itertools
import warnings
from typing import Any, Callable
from typing_extensions import TypeAlias # Python 3.10+
import numpy as np
from numpy.typing import ArrayLike
from optree.ops import tree_flatten, tree_unflatten
from optree.typing import PyTreeSpec, PyTreeTypeVar
from optree.utils import safe_zip
__all__ = ['ArrayLikeTree', 'ArrayTree', 'tree_ravel']
# pylint: disable-next=invalid-name
ArrayLikeTree: TypeAlias = PyTreeTypeVar('ArrayLikeTree', ArrayLike) # type: ignore[valid-type]
# pylint: disable-next=invalid-name
ArrayTree: TypeAlias = PyTreeTypeVar('ArrayTree', np.ndarray) # type: ignore[valid-type]
def tree_ravel(
tree: ArrayLikeTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
) -> tuple[np.ndarray, Callable[[np.ndarray], ArrayTree]]:
r"""Ravel (flatten) a pytree of arrays down to a 1D array.
>>> tree = {
... 'layer1': {
... 'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
... 'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
... },
... 'layer2': {
... 'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
... 'bias': np.arange(10, 11, dtype=np.float32).reshape((1,)),
... },
... }
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A pair ``(array, unravel_func)`` where the first element is a 1D array representing the
flattened and concatenated leaf values, with ``dtype`` determined by promoting the
``dtype``\s of leaf values, and the second element is a callable for unflattening a 1D array
of the same length back to a pytree of the same structure as the input ``tree``. If the
input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of the
default dtype is returned in the first component of the output.
"""
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat, unravel_flat = _ravel_leaves(leaves)
return flat, functools.partial(_tree_unravel, treespec, unravel_flat)
ravel_pytree = tree_ravel
def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[np.ndarray], list[np.ndarray]],
flat: np.ndarray,
/,
) -> ArrayTree:
return tree_unflatten(treespec, unravel_flat(flat))
def _ravel_leaves(
leaves: list[np.ndarray],
/,
) -> tuple[
np.ndarray,
Callable[[np.ndarray], list[np.ndarray]],
]:
if not leaves:
return (np.zeros(0), _unravel_empty)
from_dtypes = tuple(np.result_type(leaf) for leaf in leaves)
to_dtype = np.result_type(*leaves)
sizes = tuple(np.size(leaf) for leaf in leaves)
shapes = tuple(np.shape(leaf) for leaf in leaves)
indices = tuple(itertools.accumulate(sizes))
if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
raveled = np.concatenate([np.ravel(leaf) for leaf in leaves])
return (
raveled,
functools.partial(_unravel_leaves_single_dtype, indices, shapes),
)
# When there is more than one distinct input dtype, we perform type conversions and produce a
# dtype-specific unravel function.
raveled = np.concatenate([np.ravel(leaf).astype(to_dtype) for leaf in leaves])
return (
raveled,
functools.partial(_unravel_leaves, indices, shapes, from_dtypes, to_dtype),
)
def _unravel_empty(flat: np.ndarray, /) -> list[np.ndarray]:
if np.shape(flat) != (0,):
raise ValueError(
f'The unravel function expected an array of shape {(0,)}, got shape {np.shape(flat)}.',
)
return []
def _unravel_leaves_single_dtype(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: np.ndarray,
/,
) -> list[np.ndarray]:
if np.shape(flat) != (indices[-1],):
raise ValueError(
f'The unravel function expected an array of shape {(indices[-1],)}, '
f'got shape {np.shape(flat)}.',
)
chunks = np.split(flat, indices[:-1])
return [chunk.reshape(shape) for chunk, shape in safe_zip(chunks, shapes)]
def _unravel_leaves(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
from_dtypes: tuple[np.dtype, ...],
to_dtype: np.dtype,
flat: np.ndarray,
/,
) -> list[np.ndarray]:
if np.shape(flat) != (indices[-1],):
raise ValueError(
f'The unravel function expected an array of shape {(indices[-1],)}, '
f'got shape {np.shape(flat)}.',
)
array_dtype = np.result_type(flat)
if array_dtype != to_dtype:
raise ValueError(
f'The unravel function expected an array of dtype {to_dtype}, got dtype {array_dtype}.',
)
chunks = np.split(flat, indices[:-1])
with warnings.catch_warnings():
warnings.simplefilter('ignore') # ignore complex-to-real cast warning
return [
chunk.reshape(shape).astype(dtype)
for chunk, shape, dtype in safe_zip(chunks, shapes, from_dtypes)
]
@@ -0,0 +1,222 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration with PyTorch."""
# pragma: torch cover file
# pylint: disable=import-error
from __future__ import annotations
import functools
import warnings
from typing import Any, Callable
from typing_extensions import TypeAlias # Python 3.10+
import torch
from optree.ops import tree_flatten, tree_unflatten
from optree.typing import PyTreeSpec, PyTreeTypeVar
from optree.utils import safe_zip
__all__ = ['TensorTree', 'tree_ravel']
# pylint: disable-next=invalid-name
TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', torch.Tensor) # type: ignore[valid-type]
def tree_ravel(
tree: TensorTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
) -> tuple[torch.Tensor, Callable[[torch.Tensor], TensorTree]]:
r"""Ravel (flatten) a pytree of tensors down to a 1D tensor.
>>> tree = {
... 'layer1': {
... 'weight': torch.arange(0, 6, dtype=torch.float64).reshape((2, 3)),
... 'bias': torch.arange(6, 8, dtype=torch.float64).reshape((2,)),
... },
... 'layer2': {
... 'weight': torch.arange(8, 10, dtype=torch.float64).reshape((1, 2)),
... 'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,)),
... },
... }
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=torch.float64)
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
Args:
tree (pytree): a pytree of tensors to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A pair ``(tensor, unravel_func)`` where the first element is a 1D tensor representing the
flattened and concatenated leaf values, with ``dtype`` determined by promoting the
``dtype``\s of leaf values, and the second element is a callable for unflattening a 1D tensor
of the same length back to a pytree of the same structure as the input ``tree``. If the
input pytree is empty (i.e. has no leaves) then as a convention a 1D empty tensor of the
default dtype is returned in the first component of the output.
"""
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat, unravel_flat = _ravel_leaves(leaves)
return flat, functools.partial(_tree_unravel, treespec, unravel_flat)
ravel_pytree = tree_ravel
def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[torch.Tensor], list[torch.Tensor]],
flat: torch.Tensor,
/,
) -> TensorTree:
return tree_unflatten(treespec, unravel_flat(flat))
def _ravel_leaves(
leaves: list[torch.Tensor],
/,
) -> tuple[
torch.Tensor,
Callable[[torch.Tensor], list[torch.Tensor]],
]:
if not leaves:
return (torch.zeros(0), _unravel_empty)
if not all(torch.is_tensor(leaf) for leaf in leaves):
raise ValueError('All leaves must be tensors.')
from_dtypes = tuple(leaf.dtype for leaf in leaves)
to_dtype = from_dtypes[0]
for from_dtype in from_dtypes[1:]:
to_dtype = torch.promote_types(to_dtype, from_dtype)
sizes = tuple(leaf.numel() for leaf in leaves)
shapes = tuple(leaf.shape for leaf in leaves)
if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
raveled = torch.cat([torch.ravel(leaf) for leaf in leaves])
return (
raveled,
functools.partial(_unravel_leaves_single_dtype, sizes, shapes),
)
# When there is more than one distinct input dtype, we perform type conversions and produce a
# dtype-specific unravel function.
raveled = torch.cat([torch.ravel(leaf).to(to_dtype) for leaf in leaves])
return (
raveled,
functools.partial(_unravel_leaves, sizes, shapes, from_dtypes, to_dtype),
)
def _unravel_empty(flat: torch.Tensor, /) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
if flat.shape != (0,):
raise ValueError(
f'The unravel function expected a tensor of shape {(0,)}, got shape {flat.shape}.',
)
return []
def _unravel_leaves_single_dtype(
sizes: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: torch.Tensor,
/,
) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
if flat.shape != (sum(sizes),):
raise ValueError(
f'The unravel function expected a tensor of shape {(sum(sizes),)}, '
f'got shape {flat.shape}.',
)
chunks = torch.split(flat, list(sizes))
return [chunk.reshape(shape) for chunk, shape in safe_zip(chunks, shapes)]
def _unravel_leaves(
sizes: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
from_dtypes: tuple[torch.dtype, ...],
to_dtype: torch.dtype,
flat: torch.Tensor,
/,
) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
if flat.shape != (sum(sizes),):
raise ValueError(
f'The unravel function expected a tensor of shape {(sum(sizes),)}, '
f'got shape {flat.shape}.',
)
if flat.dtype != to_dtype:
raise ValueError(
f'The unravel function expected a tensor of dtype {to_dtype}, got dtype {flat.dtype}.',
)
chunks = torch.split(flat, list(sizes))
with warnings.catch_warnings():
warnings.simplefilter('ignore') # ignore complex-to-real cast warning
return [
chunk.reshape(shape).to(dtype)
for chunk, shape, dtype in safe_zip(chunks, shapes, from_dtypes)
]
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,392 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utilities for working with ``PyTree``\s.
The :mod:`optree.pytree` namespace contains aliases of ``optree.tree_*`` utilities.
>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec # doctest: +IGNORE_WHITESPACE
(
[1, 2, 3, 4, 5],
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True
.. versionadded:: 0.14.1
"""
from __future__ import annotations
import functools as _functools
import inspect as _inspect
import sys as _sys
from builtins import all as _all
from types import ModuleType as _ModuleType
from typing import TYPE_CHECKING as _TYPE_CHECKING
import optree.dataclasses as dataclasses
import optree.functools as functools
from optree.accessors import PyTreeEntry
from optree.ops import tree_accessors as accessors
from optree.ops import tree_all as all # pylint: disable=redefined-builtin
from optree.ops import tree_any as any # pylint: disable=redefined-builtin
from optree.ops import tree_broadcast_common as broadcast_common
from optree.ops import tree_broadcast_map as broadcast_map
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path
from optree.ops import tree_broadcast_prefix as broadcast_prefix
from optree.ops import tree_flatten as flatten
from optree.ops import tree_flatten_one_level as flatten_one_level
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor
from optree.ops import tree_flatten_with_path as flatten_with_path
from optree.ops import tree_is_leaf as is_leaf
from optree.ops import tree_iter as iter # pylint: disable=redefined-builtin
from optree.ops import tree_leaves as leaves
from optree.ops import tree_map as map # pylint: disable=redefined-builtin
from optree.ops import tree_map_ as map_
from optree.ops import tree_map_with_accessor as map_with_accessor
from optree.ops import tree_map_with_accessor_ as map_with_accessor_
from optree.ops import tree_map_with_path as map_with_path
from optree.ops import tree_map_with_path_ as map_with_path_
from optree.ops import tree_max as max # pylint: disable=redefined-builtin
from optree.ops import tree_min as min # pylint: disable=redefined-builtin
from optree.ops import tree_partition as partition
from optree.ops import tree_paths as paths
from optree.ops import tree_reduce as reduce
from optree.ops import tree_replace_nones as replace_nones
from optree.ops import tree_structure as structure
from optree.ops import tree_sum as sum # pylint: disable=redefined-builtin
from optree.ops import tree_transpose as transpose
from optree.ops import tree_transpose_map as transpose_map
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path
from optree.ops import tree_unflatten as unflatten
from optree.registry import dict_insertion_ordered
from optree.registry import register_pytree_node as register_node
from optree.registry import register_pytree_node_class as register_node_class
from optree.registry import unregister_pytree_node as unregister_node
from optree.typing import PyTreeKind, PyTreeSpec
from optree.version import __version__ as __version__ # pylint: disable=useless-import-alias
__all__ = [
'reexport',
'PyTreeSpec',
'PyTreeKind',
'PyTreeEntry',
'flatten',
'flatten_with_path',
'flatten_with_accessor',
'unflatten',
'iter',
'leaves',
'structure',
'paths',
'accessors',
'is_leaf',
'map',
'map_',
'map_with_path',
'map_with_path_',
'map_with_accessor',
'map_with_accessor_',
'replace_nones',
'partition',
'transpose',
'transpose_map',
'transpose_map_with_path',
'transpose_map_with_accessor',
'broadcast_prefix',
'broadcast_common',
'broadcast_map',
'broadcast_map_with_path',
'broadcast_map_with_accessor',
'reduce',
'sum',
'max',
'min',
'all',
'any',
'flatten_one_level',
'register_node',
'register_node_class',
'unregister_node',
'dict_insertion_ordered',
]
if _TYPE_CHECKING:
from collections.abc import Callable, Iterable
from typing import Any, TypeVar # pylint: disable=ungrouped-imports
from typing_extensions import ParamSpec # Python 3.10+
_P = ParamSpec('_P')
_T = TypeVar('_T')
class ReexportedModule(_ModuleType):
"""A module that re-exports APIs from another module."""
__doc__: str
def __init__(
self,
/,
name: str,
*,
namespace: str,
original: _ModuleType,
doc: str | None = None,
__all__: Iterable[str] | None = None,
__dir__: Iterable[str] | None = None,
extra_members: dict[str, Any] | None = None,
) -> None:
doc = doc or (
f'Re-exports :mod:`{original.__name__}` as :mod:`{name}` '
f'with namespace :const:`{namespace!r}`.'
)
super().__init__(name, doc)
if __all__ is None: # pragma: no branch
__all__ = {n for n in original.__all__ if n != 'reexport'}
__all__ = set(__all__)
if __dir__ is None: # pragma: no branch
__dir__ = {n for n in original.__dir__() if not n.startswith('_') and n != 'reexport'}
__dir__ = set(__dir__).intersection(__all__)
if extra_members:
for key, value in extra_members.items():
setattr(self, key, value)
__dir__.update(extra_members)
self.__namespace = namespace
self.__original = original
self.__all_set = __all__
self.__all = sorted(__all__)
self.__dir = sorted(__dir__)
@property
def __all__(self, /) -> list[str]:
"""Return the list of attributes available in this module."""
return self.__all
def __dir__(self, /) -> list[str]:
"""Return the list of attributes available in this module."""
return self.__dir.copy()
def __getattr__(self, name: str, /) -> Any:
"""Get an attribute from the re-exported module."""
if name in self.__all_set:
attr = getattr(self.__original, name)
if _inspect.isfunction(attr):
attr = self.__reexport__(attr)
setattr(self, name, attr)
return attr
raise AttributeError(f'module {self.__name__!r} has no attribute {name!r}')
def __reexport__(self, func: Callable[_P, _T], /) -> Callable[_P, _T]:
"""Re-export a function with the default namespace."""
sig = _inspect.signature(func)
if 'namespace' not in sig.parameters:
@_functools.wraps(func)
def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
return func(*args, **kwargs)
else:
@_functools.wraps(func)
def wrapped( # type: ignore[valid-type]
*args: _P.args,
namespace: str = self.__namespace,
**kwargs: _P.kwargs,
) -> _T:
return func(*args, namespace=namespace, **kwargs) # type: ignore[arg-type]
if func.__doc__: # pragma: no branch
wrapped.__doc__ = func.__doc__.replace(
"(default: :const:`''`, i.e., the global namespace)",
f'(default: :const:`{self.__namespace!r}`)',
)
wrapped.__signature__ = sig.replace( # type: ignore[attr-defined]
parameters=[
p if p.name != 'namespace' else p.replace(default=self.__namespace)
for p in sig.parameters.values()
],
)
if callable(getattr(func, 'get', None)):
wrapped.get = self.__reexport__(func.get) # type: ignore[attr-defined]
return wrapped
if _TYPE_CHECKING:
# pylint: disable-next=missing-class-docstring,too-few-public-methods
class ReexportedPyTreeModule(ReexportedModule):
__version__: str
functools: _ModuleType
dataclasses: _ModuleType
PyTreeSpec: type[PyTreeSpec] = PyTreeSpec
PyTreeKind: type[PyTreeKind] = PyTreeKind
PyTreeEntry: type[PyTreeEntry] = PyTreeEntry
flatten = staticmethod(flatten)
flatten_with_path = staticmethod(flatten_with_path)
flatten_with_accessor = staticmethod(flatten_with_accessor)
unflatten = staticmethod(unflatten)
iter = staticmethod(iter)
leaves = staticmethod(leaves)
structure = staticmethod(structure)
paths = staticmethod(paths)
accessors = staticmethod(accessors)
is_leaf = staticmethod(is_leaf)
map = staticmethod(map)
map_ = staticmethod(map_)
map_with_path = staticmethod(map_with_path)
map_with_path_ = staticmethod(map_with_path_)
map_with_accessor = staticmethod(map_with_accessor)
map_with_accessor_ = staticmethod(map_with_accessor_)
replace_nones = staticmethod(replace_nones)
partition = staticmethod(partition)
transpose = staticmethod(transpose)
transpose_map = staticmethod(transpose_map)
transpose_map_with_path = staticmethod(transpose_map_with_path)
transpose_map_with_accessor = staticmethod(transpose_map_with_accessor)
broadcast_prefix = staticmethod(broadcast_prefix)
broadcast_common = staticmethod(broadcast_common)
broadcast_map = staticmethod(broadcast_map)
broadcast_map_with_path = staticmethod(broadcast_map_with_path)
broadcast_map_with_accessor = staticmethod(broadcast_map_with_accessor)
reduce = staticmethod(reduce)
sum = staticmethod(sum)
max = staticmethod(max)
min = staticmethod(min)
all = staticmethod(all)
any = staticmethod(any)
flatten_one_level = staticmethod(flatten_one_level)
register_node = staticmethod(register_node)
register_node_class = staticmethod(register_node_class)
unregister_node = staticmethod(unregister_node)
dict_insertion_ordered = staticmethod(dict_insertion_ordered)
def reexport(*, namespace: str, module: str | None = None) -> ReexportedPyTreeModule:
"""Re-export a pytree utility module with the given namespace as default."""
raise NotImplementedError('reexport() is not available in type checking mode')
else:
def reexport(*, namespace: str, module: str | None = None) -> _ModuleType: # type: ignore[misc]
"""Re-export a pytree utility module with the given namespace as default.
>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))
This function is useful for downstream libraries that want to re-export the pytree utilities
with their own namespace:
.. code-block:: python
# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')
del optree
# foo/bar.py
from foo import pytree
@pytree.dataclasses.dataclass
class Bar:
a: int
b: float
# User code
In [1]: import foo
In [2]: foo.pytree.flatten({'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)}))
Out[2]:
(
[1, 2, 3, 4.0],
PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo')
)
In [3]: foo.pytree.functools.reduce(lambda x, y: x * y, {'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)}))
Out[3]: 24.0
.. versionadded:: 0.16.0
Args:
namespace (str): The namespace to use in the re-exported module.
module (str, optional): The name of the re-exported module.
If not provided, defaults to ``<caller_module>.pytree``. The caller module is determined
by inspecting the stack frame.
Returns:
The re-exported module.
"""
# pylint: disable-next=import-outside-toplevel
from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE
if namespace is GLOBAL_NAMESPACE:
namespace = ''
elif not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if module is None:
try:
# pylint: disable-next=protected-access
caller_module = _sys._getframemodulename(1) or '__main__' # type: ignore[attr-defined]
except AttributeError: # pragma: no cover
try:
# pylint: disable-next=protected-access
caller_module = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
caller_module = '__main__'
module = f'{caller_module}.pytree'
if not module or not _all(part.isidentifier() for part in module.split('.')):
raise ValueError(f'invalid module name: {module!r}')
for module_name in (module, f'{module}.dataclasses', f'{module}.functools'):
if module_name in _sys.modules:
raise ValueError(f'module {module_name!r} already exists')
reexported_dataclasses = ReexportedModule(
f'{module}.dataclasses',
namespace=namespace,
original=dataclasses,
)
reexported_functools = ReexportedModule(
f'{module}.functools',
namespace=namespace,
original=functools,
)
mod: ReexportedPyTreeModule = ReexportedModule( # type: ignore[assignment]
module,
namespace=namespace,
original=_sys.modules[__name__],
extra_members={
'__version__': __version__,
'dataclasses': reexported_dataclasses,
'functools': reexported_functools,
},
)
_sys.modules[module] = mod
_sys.modules[f'{module}.dataclasses'] = reexported_dataclasses
_sys.modules[f'{module}.functools'] = reexported_functools
return mod
@@ -0,0 +1,953 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for custom pytree node types."""
# pylint: disable=too-many-lines
from __future__ import annotations
import contextlib
import dataclasses
import functools
import inspect
import sys
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import itemgetter, methodcaller
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, overload
import optree._C as _C
from optree.accessors import (
AutoEntry,
MappingEntry,
NamedTupleEntry,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
from optree.typing import (
Children,
MetaData,
PyTreeKind,
StructSequence,
T,
is_namedtuple_class,
is_structseq_class,
)
from optree.utils import safe_zip, total_order_sorted, unzip2
if TYPE_CHECKING:
import builtins
from collections.abc import Collection, Generator, Iterable
from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc
# pylint: disable-next=invalid-name
CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])
__all__ = [
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
]
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+
@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS)
class PyTreeNodeRegistryEntry(Generic[T]):
"""A dataclass that stores the information of a pytree node type."""
type: builtins.type[Collection[T]]
flatten_func: FlattenFunc[T]
unflatten_func: UnflattenFunc[T]
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
_: dataclasses.KW_ONLY # Python 3.10+
path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
kind: PyTreeKind = PyTreeKind.CUSTOM
namespace: str = ''
del SLOTS
# pylint: disable-next=missing-class-docstring,too-few-public-methods
class GlobalNamespace: # pragma: no cover
__slots__: ClassVar[tuple[()]] = ()
def __repr__(self, /) -> str:
return '<GLOBAL NAMESPACE>'
__GLOBAL_NAMESPACE: str = GlobalNamespace() # type: ignore[assignment]
__REGISTRY_LOCK: Lock = Lock()
del GlobalNamespace
if TYPE_CHECKING:
from typing_extensions import ParamSpec # Python 3.10+
_P = ParamSpec('_P')
_T = TypeVar('_T')
_GetP = ParamSpec('_GetP')
_GetT = TypeVar('_GetT')
class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
raise NotImplementedError
# pylint: disable-next=missing-function-docstring
def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
raise NotImplementedError
def _add_get(
get: Callable[_GetP, _GetT],
/,
) -> Callable[
[Callable[_P, _T]],
_CallableWithGet[_P, _T, _GetP, _GetT],
]:
def decorator(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
func.get = get # type: ignore[attr-defined]
return func # type: ignore[return-value]
return decorator
@overload
def pytree_node_registry_get(
cls: type,
/,
*,
namespace: str = '',
) -> PyTreeNodeRegistryEntry | None: ...
@overload
def pytree_node_registry_get(
cls: None = None,
/,
*,
namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry]: ...
# pylint: disable-next=too-many-return-statements,too-many-branches
def pytree_node_registry_get( # noqa: C901
cls: type | None = None,
/,
*,
namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
"""Lookup the pytree node registry.
>>> register_pytree_node.get() # doctest: +IGNORE_WHITESPACE,ELLIPSIS
{
<class 'NoneType'>: PyTreeNodeRegistryEntry(
type=<class 'NoneType'>,
flatten_func=<function ...>,
unflatten_func=<function ...>,
path_entry_type=<class 'optree.PyTreeEntry'>,
kind=<PyTreeKind.NONE: 2>,
namespace=''
),
<class 'tuple'>: PyTreeNodeRegistryEntry(
type=<class 'tuple'>,
flatten_func=<function ...>,
unflatten_func=<function ...>,
path_entry_type=<class 'optree.SequenceEntry'>,
kind=<PyTreeKind.TUPLE: 3>,
namespace=''
),
<class 'list'>: PyTreeNodeRegistryEntry(
type=<class 'list'>,
flatten_func=<function ...>,
unflatten_func=<function ...>,
path_entry_type=<class 'optree.SequenceEntry'>,
kind=<PyTreeKind.LIST: 4>,
namespace=''
),
...
}
>>> register_pytree_node.get(defaultdict) # doctest: +IGNORE_WHITESPACE,ELLIPSIS
PyTreeNodeRegistryEntry(
type=<class 'collections.defaultdict'>,
flatten_func=<function ...>,
unflatten_func=<function ...>,
path_entry_type=<class 'optree.MappingEntry'>,
kind=<PyTreeKind.DEFAULTDICT: 8>,
namespace=''
)
>>> register_pytree_node.get(frozenset) # frozenset is considered as a leaf node
None
Args:
cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
the registered pytree nodes in the namespace are returned.
namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
global namespace is used.
Returns:
If the ``cls`` is not provided, a dictionary of all the registered pytree nodes in the
namespace is returned. If the ``cls`` is provided, the corresponding registry entry is
returned if the ``cls`` is registered as a pytree node. Otherwise, :data:`None` is returned,
i.e., the ``cls`` is represented as a leaf node.
"""
if namespace is __GLOBAL_NAMESPACE:
namespace = ''
if (
cls is not None
and cls is not namedtuple # noqa: PYI024
and not inspect.isclass(cls)
):
raise TypeError(f'Expected a class or None, got {cls!r}.') # pragma: !=3.9 cover
if not isinstance(namespace, str):
raise TypeError( # pragma: !=3.9 cover
f'The namespace must be a string, got {namespace!r}.',
)
if cls is None:
namespaces = frozenset({namespace, ''})
with __REGISTRY_LOCK:
registry = {
handler.type: handler
for handler in _NODETYPE_REGISTRY.values()
if handler.namespace in namespaces
}
if _C.is_dict_insertion_ordered(namespace):
registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
return registry
if namespace != '':
handler = _NODETYPE_REGISTRY.get((namespace, cls))
if handler is not None:
return handler
if _C.is_dict_insertion_ordered(namespace):
if cls is dict:
return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
if cls is defaultdict:
return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
handler = _NODETYPE_REGISTRY.get(cls)
if handler is not None:
return handler
if is_structseq_class(cls):
return _NODETYPE_REGISTRY.get(StructSequence)
if is_namedtuple_class(cls):
return _NODETYPE_REGISTRY.get(namedtuple) # type: ignore[call-overload] # noqa: PYI024
return None
@_add_get(pytree_node_registry_get)
def register_pytree_node(
cls: type[Collection[T]],
/,
flatten_func: FlattenFunc[T],
unflatten_func: UnflattenFunc[T],
*,
path_entry_type: type[PyTreeEntry] = AutoEntry,
namespace: str,
) -> type[Collection[T]]:
"""Extend the set of types that are considered internal nodes in pytrees.
See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
the same class in different namespaces for different use cases.
.. warning::
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
prevent accidental collisions between different libraries that may register the same type.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
and returning a triple or optionally a pair, with (1) an iterable for the children to be
flattened recursively, and (2) some hashable metadata to be stored in the treespec and
to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
entries to the corresponding children. If the entries are not provided or given by
:data:`None`, then `range(len(children))` will be used.
unflatten_func (callable): A function taking two arguments: the metadata that was returned
by ``flatten_func`` and stored in the treespec, and the unflattened children. The
function should return an instance of ``cls``.
path_entry_type (type, optional): The type of the path entry to be used in the treespec.
(default: :class:`AutoEntry`)
namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
This is used to isolate the registry from other modules that might register a different
custom behavior for the same type.
Returns:
The same type as the input ``cls``.
Raises:
TypeError: If the input type is not a class.
TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
TypeError: If the namespace is not a string.
ValueError: If the namespace is an empty string.
ValueError: If the type is already registered in the registry.
.. versionadded:: 0.12.0
The ``path_entry_type`` argument to specify the path entry type used in
:meth:`PyTreeSpec.accessors` and :func:`tree_flatten_with_accessor`.
If not provided, :class:`AutoEntry` will be used.
Examples:
>>> # Registry a Python type with lambda functions
>>> register_pytree_node(
... set,
... lambda s: (sorted(s), None, None),
... lambda _, children: set(children),
... namespace='set',
... )
<class 'set'>
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().detach().numpy(),),
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
... ),
... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
... )
<class 'torch.Tensor'>
>>> # doctest: +SKIP
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy')
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
},
namespace='torch2numpy'
)
)
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
... def flatparam2tensor(metadata, children):
... return children[0].reshape(metadata)
...
... register_pytree_node(
... torch.Tensor,
... flatten_func=tensor2flatparam,
... unflatten_func=flatparam2tensor,
... namespace='tensor2flatparam',
... )
<class 'torch.Tensor'>
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam')
(
[
Parameter containing: tensor([0., 0.], requires_grad=True),
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
},
namespace='tensor2flatparam'
)
)
""" # pylint: disable=line-too-long
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
registration_key: type | tuple[str, type]
if namespace is __GLOBAL_NAMESPACE:
registration_key = cls
namespace = ''
else:
registration_key = (namespace, cls)
with __REGISTRY_LOCK:
_C.register_node(
cls,
flatten_func,
unflatten_func,
path_entry_type,
namespace,
)
_NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
cls,
flatten_func,
unflatten_func,
path_entry_type=path_entry_type,
namespace=namespace,
)
return cls
del pytree_node_registry_get, _add_get
@overload
def register_pytree_node_class(
cls: str | None = None,
/,
*,
path_entry_type: type[PyTreeEntry] | None = None,
namespace: str | None = None,
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...
@overload
def register_pytree_node_class(
cls: CustomTreeNodeType,
/,
*,
path_entry_type: type[PyTreeEntry] | None,
namespace: str,
) -> CustomTreeNodeType: ...
# pylint: disable-next=too-many-branches
def register_pytree_node_class( # noqa: C901
cls: CustomTreeNodeType | str | None = None,
/,
*,
path_entry_type: type[PyTreeEntry] | None = None,
namespace: str | None = None,
) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]:
"""Extend the set of types that are considered internal nodes in pytrees.
See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
the same class in different namespaces for different use cases.
.. warning::
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
prevent accidental collisions between different libraries that may register the same type.
Args:
cls (type, optional): A Python type to treat as an internal pytree node.
path_entry_type (type, optional): The type of the path entry to be used in the treespec.
(default: :class:`AutoEntry`)
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type.
Returns:
The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator
function that registers the class as a pytree node.
Raises:
TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
TypeError: If the namespace is not a string.
TypeError: If the class does not define the required method pairs.
ValueError: If the namespace is an empty string.
ValueError: If the type is already registered in the registry.
.. versionadded:: 0.12.0
The ``TREE_PATH_ENTRY_TYPE`` class variable to specify the path entry type used in
:meth:`PyTreeSpec.accessors` and :func:`tree_flatten_with_accessor`.
If not provided, :class:`AutoEntry` will be used.
.. versionadded:: 0.18.0
Previously, this function looked for methods named ``tree_flatten`` and ``tree_unflatten``
for the given class. Since version 0.18.0, it prefers methods named ``__tree_flatten__``
and ``__tree_unflatten__`` instead. The old method names are still supported for
backward compatibility, but it is recommended to use the new method names.
The method resolution follows this priority:
1. If both ``__tree_flatten__`` and ``__tree_unflatten__`` are defined, use them directly.
2. If both ``tree_flatten`` and ``tree_unflatten`` are defined, wrap them as dunder methods.
3. If neither complete pair is available, raise a :exc:`TypeError` suggesting the new method names.
This function is a thin wrapper around :func:`register_pytree_node`, and provides a
class-oriented interface:
.. code-block:: python
@register_pytree_node_class(namespace='foo')
class Special:
TREE_PATH_ENTRY_TYPE = GetAttrEntry
def __init__(self, x, y):
self.x = x
self.y = y
def __tree_flatten__(self):
return ((self.x, self.y), None, ('x', 'y'))
@classmethod
def __tree_unflatten__(cls, metadata, children):
return cls(*children)
@register_pytree_node_class('mylist')
class MyList(UserList):
TREE_PATH_ENTRY_TYPE = SequenceEntry
def __tree_flatten__(self):
return self.data, None, None
@classmethod
def __tree_unflatten__(cls, metadata, children):
return cls(*children)
# Legacy style (still supported but not recommended)
@register_pytree_node_class(namespace='legacy')
class LegacyStyleMyList(UserList):
def tree_flatten(self):
# Implementation automatically wrapped as __tree_flatten__
return self.data, None, None
@classmethod
def tree_unflatten(cls, metadata, children):
# Implementation automatically wrapped as __tree_unflatten__
return cls(*children)
"""
if cls is __GLOBAL_NAMESPACE or isinstance(cls, str):
if namespace is not None:
raise ValueError('Cannot specify `namespace` when the first argument is a string.')
if cls == '':
raise ValueError('The namespace cannot be an empty string.')
cls, namespace = None, cls
if namespace is None:
raise ValueError('Must specify `namespace` when the first argument is a class.')
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
if cls is None:
def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType:
return register_pytree_node_class(
cls,
path_entry_type=path_entry_type,
namespace=namespace,
)
return decorator
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if path_entry_type is None:
path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
# Check for dunder-styled methods first (preferred since 0.18.0)
if not all(
callable(getattr(cls, method, None))
for method in ('__tree_flatten__', '__tree_unflatten__')
):
# Check for old-styled methods (backward compatibility)
if not all(
callable(getattr(cls, method, None)) for method in ('tree_flatten', 'tree_unflatten')
):
raise TypeError(
f'{cls!r} must define both `__tree_flatten__` and `__tree_unflatten__` methods '
'for registration as a pytree node.',
)
# Add dunder-styled wrapper methods to the class
# pylint: disable=no-member
@functools.wraps(cls.tree_flatten)
def __tree_flatten__( # noqa: N807
self: CustomTreeNode[T],
/,
) -> tuple[Children[T], MetaData] | tuple[Children[T], MetaData, Iterable[Any] | None]:
return self.tree_flatten() # type: ignore[attr-defined]
@classmethod # type: ignore[misc]
@functools.wraps(getattr(cls.tree_unflatten, '__func__', cls.tree_unflatten))
def __tree_unflatten__( # noqa: N807
cls: type[CustomTreeNode[T]],
metadata: MetaData,
children: Children[T],
/,
) -> CustomTreeNode[T]:
return cls.tree_unflatten(metadata, children) # type: ignore[attr-defined]
# pylint: enable=no-member
cls.__tree_flatten__ = __tree_flatten__
cls.__tree_unflatten__ = __tree_unflatten__
register_pytree_node(
cls,
methodcaller('__tree_flatten__'),
cls.__tree_unflatten__,
path_entry_type=path_entry_type,
namespace=namespace,
)
return cls
def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry:
"""Remove a type from the pytree node registry.
See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.
This function is the inverse operation of function :func:`register_pytree_node`.
Args:
cls (type): A Python type to remove from the pytree node registry.
namespace (str): The namespace of the pytree node registry to remove the type from.
Returns:
The removed registry entry.
Raises:
TypeError: If the input type is not a class.
TypeError: If the namespace is not a string.
ValueError: If the namespace is an empty string.
ValueError: If the type is a built-in type that cannot be unregistered.
ValueError: If the type is not found in the registry.
Examples:
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
... set,
... lambda s: (sorted(s), None, None),
... lambda _, children: set(children),
... namespace='temp',
... )
<class 'set'>
>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')
"""
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
registration_key: type | tuple[str, type]
if namespace is __GLOBAL_NAMESPACE:
registration_key = cls
namespace = ''
else:
registration_key = (namespace, cls)
with __REGISTRY_LOCK:
_C.unregister_node(cls, namespace)
return _NODETYPE_REGISTRY.pop(registration_key)
@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]:
"""Context manager to temporarily set the dictionary sorting mode.
This context manager is used to temporarily set the dictionary sorting mode for a specific
namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary
should be sorted or keeping the insertion order when flattening a pytree.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
(
[1, 2, 3, 4, 5],
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE
... tree_flatten(tree, namespace='some-namespace')
(
[2, 3, 4, 1, 5],
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
)
.. warning::
The dictionary sorting mode is a global setting and is **not thread-safe**. It is
recommended to use this context manager in a single-threaded environment.
Args:
mode (bool): The dictionary sorting mode to set.
namespace (str): The namespace to set the dictionary sorting mode for.
"""
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
if namespace is __GLOBAL_NAMESPACE:
namespace = ''
with __REGISTRY_LOCK:
prev = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False)
_C.set_dict_insertion_ordered(bool(mode), namespace)
try:
yield
finally:
with __REGISTRY_LOCK:
_C.set_dict_insertion_ordered(prev, namespace)
def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]:
return total_order_sorted(items, key=itemgetter(0))
def _none_flatten(_: None, /) -> tuple[tuple[()], None]:
return (), None
def _none_unflatten(_: None, children: Iterable[Any], /) -> None:
sentinel = object()
if next(iter(children), sentinel) is not sentinel:
raise ValueError('Expected no children.')
def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]:
return tup, None
def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]:
return tuple(children)
def _list_flatten(lst: list[T], /) -> tuple[list[T], None]:
return lst, None
def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]:
return list(children)
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
keys, values = unzip2(_sorted_items(dct.items()))
return values, list(keys), keys
def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
return dict(safe_zip(keys, values))
def _dict_insertion_ordered_flatten(
dct: dict[KT, VT],
/,
) -> tuple[
tuple[VT, ...],
list[KT],
tuple[KT, ...],
]:
keys, values = unzip2(dct.items())
return values, list(keys), keys
def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
return dict(safe_zip(keys, values))
def _ordereddict_flatten(
dct: OrderedDict[KT, VT],
/,
) -> tuple[
tuple[VT, ...],
list[KT],
tuple[KT, ...],
]:
keys, values = unzip2(dct.items())
return values, list(keys), keys
def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
return OrderedDict(safe_zip(keys, values))
def _defaultdict_flatten(
dct: defaultdict[KT, VT],
/,
) -> tuple[
tuple[VT, ...],
tuple[Callable[[], VT] | None, list[KT]],
tuple[KT, ...],
]:
values, dict_metadata, entries = _dict_flatten(dct)
return values, (dct.default_factory, dict_metadata), entries
def _defaultdict_unflatten(
metadata: tuple[Callable[[], VT], list[KT]],
values: Iterable[VT],
/,
) -> defaultdict[KT, VT]:
default_factory, dict_metadata = metadata
return defaultdict(default_factory, _dict_unflatten(dict_metadata, values))
def _defaultdict_insertion_ordered_flatten(
dct: defaultdict[KT, VT],
/,
) -> tuple[
tuple[VT, ...],
tuple[Callable[[], VT] | None, list[KT]],
tuple[KT, ...],
]:
values, dict_metadata, entries = _dict_insertion_ordered_flatten(dct)
return values, (dct.default_factory, dict_metadata), entries
def _defaultdict_insertion_ordered_unflatten(
metadata: tuple[Callable[[], VT], list[KT]],
values: Iterable[VT],
/,
) -> defaultdict[KT, VT]:
default_factory, dict_metadata = metadata
return defaultdict(default_factory, _dict_insertion_ordered_unflatten(dict_metadata, values))
def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
return deq, deq.maxlen
def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
return deque(children, maxlen=maxlen)
def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]: # type: ignore[type-arg]
return tup, type(tup)
# pylint: disable-next=line-too-long
def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]: # type: ignore[type-arg]
return cls(*children) # type: ignore[call-overload]
def _structseq_flatten(seq: StructSequence[T], /) -> tuple[tuple[T, ...], type[StructSequence[T]]]:
return seq, type(seq)
def _structseq_unflatten(
cls: type[StructSequence[T]],
children: Iterable[T],
/,
) -> StructSequence[T]:
return cls(children)
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
type(None): PyTreeNodeRegistryEntry(
type(None), # type: ignore[arg-type]
_none_flatten,
_none_unflatten,
path_entry_type=PyTreeEntry,
kind=PyTreeKind.NONE,
),
tuple: PyTreeNodeRegistryEntry(
tuple,
_tuple_flatten,
_tuple_unflatten,
path_entry_type=SequenceEntry,
kind=PyTreeKind.TUPLE,
),
list: PyTreeNodeRegistryEntry(
list,
_list_flatten,
_list_unflatten,
path_entry_type=SequenceEntry,
kind=PyTreeKind.LIST,
),
dict: PyTreeNodeRegistryEntry(
dict,
_dict_flatten,
_dict_unflatten,
path_entry_type=MappingEntry,
kind=PyTreeKind.DICT,
),
namedtuple: PyTreeNodeRegistryEntry( # type: ignore[dict-item] # noqa: PYI024
namedtuple, # type: ignore[arg-type] # noqa: PYI024
_namedtuple_flatten,
_namedtuple_unflatten,
path_entry_type=NamedTupleEntry,
kind=PyTreeKind.NAMEDTUPLE,
),
OrderedDict: PyTreeNodeRegistryEntry(
OrderedDict,
_ordereddict_flatten,
_ordereddict_unflatten,
path_entry_type=MappingEntry,
kind=PyTreeKind.ORDEREDDICT,
),
defaultdict: PyTreeNodeRegistryEntry(
defaultdict,
_defaultdict_flatten,
_defaultdict_unflatten,
path_entry_type=MappingEntry,
kind=PyTreeKind.DEFAULTDICT,
),
deque: PyTreeNodeRegistryEntry(
deque,
_deque_flatten,
_deque_unflatten,
path_entry_type=SequenceEntry,
kind=PyTreeKind.DEQUE,
),
StructSequence: PyTreeNodeRegistryEntry(
StructSequence,
_structseq_flatten,
_structseq_unflatten,
path_entry_type=StructSequenceEntry,
kind=PyTreeKind.STRUCTSEQUENCE,
),
}
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
dict,
_dict_insertion_ordered_flatten,
_dict_insertion_ordered_unflatten,
path_entry_type=MappingEntry,
kind=PyTreeKind.DICT,
)
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
defaultdict,
_defaultdict_insertion_ordered_flatten,
_defaultdict_insertion_ordered_unflatten,
path_entry_type=MappingEntry,
kind=PyTreeKind.DEFAULTDICT,
)
@@ -0,0 +1,55 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The :mod:`optree.treespec` namespace contains constructors for class :class:`optree.PyTreeSpec`.
>>> import optree.treespec as treespec
>>> treespec.leaf()
PyTreeSpec(*)
>>> treespec.none()
PyTreeSpec(None)
>>> treespec.dict({'a': treespec.leaf(), 'b': treespec.leaf()})
PyTreeSpec({'a': *, 'b': *})
.. versionadded:: 0.14.1
"""
from __future__ import annotations
from optree.ops import treespec_defaultdict as defaultdict
from optree.ops import treespec_deque as deque
from optree.ops import treespec_dict as dict # pylint: disable=redefined-builtin
from optree.ops import treespec_from_collection as from_collection
from optree.ops import treespec_leaf as leaf
from optree.ops import treespec_list as list # pylint: disable=redefined-builtin
from optree.ops import treespec_namedtuple as namedtuple
from optree.ops import treespec_none as none
from optree.ops import treespec_ordereddict as ordereddict
from optree.ops import treespec_structseq as structseq
from optree.ops import treespec_tuple as tuple # pylint: disable=redefined-builtin
__all__ = [
'leaf',
'none',
'tuple',
'list',
'dict',
'namedtuple',
'ordereddict',
'defaultdict',
'deque',
'structseq',
'from_collection',
]
@@ -0,0 +1,582 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Typing utilities for OpTree."""
from __future__ import annotations
import abc
import functools
import platform
import sys
import threading
import types
from builtins import dict as Dict # noqa: N812
from builtins import list as List # noqa: N812
from builtins import tuple as Tuple # noqa: N812
from collections import OrderedDict
from collections import defaultdict as DefaultDict # noqa: N812
from collections import deque as Deque # noqa: N812
from collections.abc import (
Collection,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
Sequence,
ValuesView,
)
from typing import (
Any,
Callable,
ClassVar,
Final,
ForwardRef,
Generic,
Optional,
Protocol,
TypeVar,
Union,
final,
get_origin,
runtime_checkable,
)
from typing_extensions import (
NamedTuple, # Generic NamedTuple: Python 3.11+
Never, # Python 3.11+
ParamSpec, # Python 3.10+
Self, # Python 3.11+
TypeAlias, # Python 3.10+
TypeAliasType, # Python 3.12+
)
from weakref import WeakKeyDictionary
import optree._C as _C
from optree._C import PyTreeKind, PyTreeSpec
from optree.accessors import (
AutoEntry,
DataclassEntry,
FlattenedEntry,
GetAttrEntry,
GetItemEntry,
MappingEntry,
NamedTupleEntry,
PyTreeAccessor,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
__all__ = [
'PyTreeSpec',
'PyTreeDef',
'PyTreeKind',
'PyTree',
'PyTreeTypeVar',
'CustomTreeNode',
'Children',
'MetaData',
'FlattenFunc',
'UnflattenFunc',
'PyTreeEntry',
'GetItemEntry',
'GetAttrEntry',
'FlattenedEntry',
'AutoEntry',
'SequenceEntry',
'MappingEntry',
'NamedTupleEntry',
'StructSequenceEntry',
'DataclassEntry',
'PyTreeAccessor',
'is_namedtuple',
'is_namedtuple_class',
'is_namedtuple_instance',
'namedtuple_fields',
'is_structseq',
'is_structseq_class',
'is_structseq_instance',
'structseq_fields',
'T',
'S',
'U',
'KT',
'VT',
'P',
'F',
'Iterable',
'Sequence',
'Tuple',
'List',
'Dict',
'NamedTuple',
'OrderedDict',
'DefaultDict',
'Deque',
'StructSequence',
]
PyTreeDef: TypeAlias = PyTreeSpec # alias
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
KT = TypeVar('KT')
VT = TypeVar('VT')
P = ParamSpec('P')
F = TypeVar('F', bound=Callable[..., Any])
Children: TypeAlias = Iterable[T]
MetaData: TypeAlias = Optional[Hashable]
@runtime_checkable
class CustomTreeNode(Protocol[T]): # pylint: disable=too-few-public-methods
"""The abstract base class for custom pytree nodes."""
def __tree_flatten__(
self,
/,
) -> (
# Use `range(num_children)` as path entries
tuple[Children[T], MetaData]
|
# With optionally implemented path entries
tuple[Children[T], MetaData, Iterable[Any] | None]
):
"""Flatten the custom pytree node into children and metadata."""
@classmethod
def __tree_unflatten__(cls, metadata: MetaData, children: Children[T], /) -> Self:
"""Unflatten the children and metadata into the custom pytree node."""
_UnionType = type(Union[int, str])
try: # pragma: no cover
from typing import _tp_cache # type: ignore[attr-defined] # pylint: disable=ungrouped-imports
except ImportError: # pragma: no cover
def _tp_cache(func: Callable[P, T], /) -> Callable[P, T]:
cached = functools.lru_cache(func)
@functools.wraps(func)
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return cached(*args, **kwargs) # type: ignore[arg-type]
except TypeError:
# All real errors (not unhashable args) are raised below.
return func(*args, **kwargs)
return inner
class PyTree(Generic[T]): # pragma: no cover
"""Generic PyTree type.
>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree # doctest: +IGNORE_WHITESPACE
typing.Union[torch.Tensor,
tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
list[ForwardRef('PyTree[torch.Tensor]')],
dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
collections.deque[ForwardRef('PyTree[torch.Tensor]')],
optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
"""
__slots__: ClassVar[tuple[()]] = ()
__instances__: ClassVar[
WeakKeyDictionary[
TypeAliasType,
tuple[type | TypeAliasType, str | None],
]
] = WeakKeyDictionary()
__instance_lock__: ClassVar[threading.Lock] = threading.Lock()
@_tp_cache
def __class_getitem__( # noqa: C901 # pylint: disable=too-many-branches
cls,
item: (
type[T]
| TypeAliasType
| tuple[type[T] | TypeAliasType]
| tuple[type[T] | TypeAliasType, str | None]
),
) -> TypeAliasType:
"""Instantiate a PyTree type with the given type."""
if not isinstance(item, tuple):
item = (item, None)
if len(item) == 1:
item = (item[0], None)
elif len(item) != 2:
raise TypeError(
f'{cls.__name__}[...] only supports a tuple of 2 items, '
f'a parameter and a string of type name, got {item!r}.',
)
param, name = item
if name is not None and not isinstance(name, str):
raise TypeError(
f'{cls.__name__}[...] only supports a tuple of 2 items, '
f'a parameter and a string of type name, got {item!r}.',
)
if isinstance(param, _UnionType) and get_origin(param) is Union: # type: ignore[unreachable]
with cls.__instance_lock__: # type: ignore[unreachable]
try:
if param in cls.__instances__:
return param # PyTree[PyTree[T]] -> PyTree[T]
except TypeError:
pass # non-hashable type
if name is not None:
recurse_ref = ForwardRef(name)
elif isinstance(param, TypeVar):
recurse_ref = ForwardRef(f'{cls.__name__}[{param.__name__}]') # type: ignore[unreachable]
elif isinstance(param, type):
if param.__module__ == 'builtins':
typename = param.__qualname__
else:
try:
typename = f'{param.__module__}.{param.__qualname__}'
except AttributeError:
typename = f'{param.__module__}.{param.__name__}'
recurse_ref = ForwardRef(f'{cls.__name__}[{typename}]')
else:
recurse_ref = ForwardRef(f'{cls.__name__}[{param!r}]')
pytree_alias = Union[
param, # type: ignore[valid-type]
Tuple[recurse_ref, ...], # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
List[recurse_ref], # type: ignore[valid-type]
Dict[Any, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
Deque[recurse_ref], # type: ignore[valid-type]
CustomTreeNode[recurse_ref], # type: ignore[valid-type]
]
with cls.__instance_lock__:
cls.__instances__[pytree_alias] = (param, name) # type: ignore[index]
return pytree_alias # type: ignore[return-value]
def __new__(cls, /) -> Never: # pylint: disable=arguments-differ
"""Prohibit instantiation."""
raise TypeError('Cannot instantiate special typing classes.')
def __init_subclass__(cls, /, *args: Any, **kwargs: Any) -> Never:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')
def __getitem__(self, key: Any, /) -> PyTree[T] | T:
"""Emulate collection-like behavior."""
raise NotImplementedError
def __getattr__(self, name: str, /) -> PyTree[T] | T:
"""Emulate dataclass-like behavior."""
raise NotImplementedError
def __contains__(self, key: Any, /) -> bool:
"""Emulate collection-like behavior."""
raise NotImplementedError
def __len__(self, /) -> int:
"""Emulate collection-like behavior."""
raise NotImplementedError
def __iter__(self, /) -> Iterator[PyTree[T] | T | Any]:
"""Emulate collection-like behavior."""
raise NotImplementedError
def index(self, key: Any, /) -> int:
"""Emulate sequence-like behavior."""
raise NotImplementedError
def count(self, key: Any, /) -> int:
"""Emulate sequence-like behavior."""
raise NotImplementedError
def get(self, key: Any, /, default: S | None = None) -> PyTree[T] | T | S | None:
"""Emulate mapping-like behavior."""
raise NotImplementedError
def keys(self, /) -> KeysView[Any]:
"""Emulate mapping-like behavior."""
raise NotImplementedError
def values(self, /) -> ValuesView[PyTree[T] | T]:
"""Emulate mapping-like behavior."""
raise NotImplementedError
def items(self, /) -> ItemsView[Any, PyTree[T] | T]:
"""Emulate mapping-like behavior."""
raise NotImplementedError
# pylint: disable-next=too-few-public-methods
class PyTreeTypeVar: # pragma: no cover
"""Type variable for PyTree.
>>> import torch
>>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
>>> TensorTree # doctest: +IGNORE_WHITESPACE
typing.Union[torch.Tensor,
tuple[ForwardRef('TensorTree'), ...],
list[ForwardRef('TensorTree')],
dict[typing.Any, ForwardRef('TensorTree')],
collections.deque[ForwardRef('TensorTree')],
optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
"""
@_tp_cache
def __new__(cls, /, name: str, param: type | TypeAliasType) -> TypeAliasType: # type: ignore[misc]
"""Instantiate a PyTree type variable with the given name and parameter."""
if not isinstance(name, str):
raise TypeError(f'{cls.__name__} only supports a string of type name, got {name!r}.')
return PyTree[param, name] # type: ignore[misc,valid-type]
def __init_subclass__(cls, /, *args: Any, **kwargs: Any) -> Never:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')
class FlattenFunc(Protocol[T]): # pylint: disable=too-few-public-methods
"""The type stub class for flatten functions."""
@abc.abstractmethod
def __call__(
self,
container: Collection[T],
/,
) -> tuple[Children[T], MetaData] | tuple[Children[T], MetaData, Iterable[Any] | None]:
"""Flatten the container into children and metadata."""
class UnflattenFunc(Protocol[T]): # pylint: disable=too-few-public-methods
"""The type stub class for unflatten functions."""
@abc.abstractmethod
def __call__(self, metadata: MetaData, children: Children[T], /) -> Collection[T]:
"""Unflatten the children and metadata back into the container."""
def _override_with_(
cxx_implementation: Callable[P, T],
/,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator to override the Python implementation with the C++ implementation.
>>> @_override_with_(any)
... def my_any(iterable):
... for elem in iterable:
... if elem:
... return True
... return False
...
>>> my_any([False, False, True, False, False, True]) # run at C speed
True
"""
def wrapper(python_implementation: Callable[P, T], /) -> Callable[P, T]:
@functools.wraps(python_implementation)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
return cxx_implementation(*args, **kwargs)
wrapped.__cxx_implementation__ = cxx_implementation # type: ignore[attr-defined]
wrapped.__python_implementation__ = python_implementation # type: ignore[attr-defined]
return wrapped
return wrapper
@_override_with_(_C.is_namedtuple)
def is_namedtuple(obj: object | type, /) -> bool:
"""Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
cls = obj if isinstance(obj, type) else type(obj)
return is_namedtuple_class(cls)
@_override_with_(_C.is_namedtuple_instance)
def is_namedtuple_instance(obj: object, /) -> bool:
"""Return whether the object is an instance of namedtuple."""
return is_namedtuple_class(type(obj))
@_override_with_(_C.is_namedtuple_class)
def is_namedtuple_class(cls: type, /) -> bool:
"""Return whether the class is a subclass of namedtuple."""
return (
isinstance(cls, type)
and issubclass(cls, tuple)
and isinstance(getattr(cls, '_fields', None), tuple)
# pylint: disable-next=unidiomatic-typecheck
and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined]
and callable(getattr(cls, '_make', None))
and callable(getattr(cls, '_asdict', None))
)
@_override_with_(_C.namedtuple_fields)
def namedtuple_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]:
"""Return the field names of a namedtuple."""
if isinstance(obj, type):
cls = obj
if not is_namedtuple_class(cls):
raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
else:
cls = type(obj)
if not is_namedtuple_class(cls):
raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
return cls._fields # type: ignore[attr-defined]
_T_co = TypeVar('_T_co', covariant=True)
class StructSequenceMeta(type):
"""The metaclass for PyStructSequence stub type."""
def __subclasscheck__(cls, subclass: type, /) -> bool:
"""Return whether the class is a PyStructSequence type.
>>> import time
>>> issubclass(time.struct_time, StructSequence)
True
>>> class MyTuple(tuple):
... n_fields = 2
... n_sequence_fields = 2
... n_unnamed_fields = 0
>>> issubclass(MyTuple, StructSequence)
False
"""
return is_structseq_class(subclass)
def __instancecheck__(cls, instance: Any, /) -> bool:
"""Return whether the object is a PyStructSequence instance.
>>> import sys
>>> isinstance(sys.float_info, StructSequence)
True
>>> isinstance((1, 2), StructSequence)
False
"""
return is_structseq_instance(instance)
# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `StructSequence` classes are unsubclassable, so are all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
@final
class StructSequence(tuple[_T_co, ...], metaclass=StructSequenceMeta):
"""A generic type stub for CPython's ``PyStructSequence`` type."""
__slots__: ClassVar[tuple[()]] = ()
n_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name
n_sequence_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name
n_unnamed_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name
def __init_subclass__(cls, /) -> Never:
"""Prohibit subclassing."""
raise TypeError("type 'StructSequence' is not an acceptable base type")
# pylint: disable-next=unused-argument,redefined-builtin
def __new__(cls, /, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
"""Create a new :class:`StructSequence` instance."""
raise NotImplementedError
structseq: TypeAlias = StructSequence # noqa: PYI042 # pylint: disable=invalid-name
del StructSequenceMeta
@_override_with_(_C.is_structseq)
def is_structseq(obj: object | type, /) -> bool:
"""Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
cls = obj if isinstance(obj, type) else type(obj)
return is_structseq_class(cls)
@_override_with_(_C.is_structseq_instance)
def is_structseq_instance(obj: object, /) -> bool:
"""Return whether the object is an instance of PyStructSequence."""
return is_structseq_class(type(obj))
# Set if the type allows subclassing (see CPython's Include/object.h)
Py_TPFLAGS_BASETYPE: int = _C.Py_TPFLAGS_BASETYPE # (1UL << 10) # pylint: disable=invalid-name
@_override_with_(_C.is_structseq_class)
def is_structseq_class(cls: type, /) -> bool:
"""Return whether the class is a class of PyStructSequence."""
if (
isinstance(cls, type)
# Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
and cls.__bases__ == (tuple,)
# Check PyStructSequence members
and isinstance(getattr(cls, 'n_fields', None), int)
and isinstance(getattr(cls, 'n_sequence_fields', None), int)
and isinstance(getattr(cls, 'n_unnamed_fields', None), int)
):
# Check the type does not allow subclassing
if platform.python_implementation() == 'PyPy': # pragma: pypy cover
try:
types.new_class('subclass', bases=(cls,))
except (AssertionError, TypeError):
return True
return False
return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # pragma: pypy no cover
return False
# pylint: disable-next=line-too-long
StructSequenceFieldType: type[types.MemberDescriptorType] = type(type(sys.version_info).major) # type: ignore[assignment]
@_override_with_(_C.structseq_fields)
def structseq_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]:
"""Return the field names of a PyStructSequence."""
if isinstance(obj, type):
cls = obj
if not is_structseq_class(cls):
raise TypeError(f'Expected a PyStructSequence type, got {cls!r}.')
else:
cls = type(obj)
if not is_structseq_class(cls):
raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')
if platform.python_implementation() == 'PyPy': # pragma: pypy cover
indices_by_name = {
name: member.index # type: ignore[attr-defined]
for name, member in vars(cls).items()
if isinstance(member, StructSequenceFieldType)
}
fields = sorted(indices_by_name, key=indices_by_name.get) # type: ignore[arg-type]
else: # pragma: pypy no cover
fields = [
name
for name, member in vars(cls).items()
if isinstance(member, StructSequenceFieldType)
]
return tuple(fields[: cls.n_sequence_fields]) # type: ignore[attr-defined]
del _tp_cache
@@ -0,0 +1,117 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for OpTree."""
from __future__ import annotations
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Callable, overload
if TYPE_CHECKING:
from optree.typing import S, T, U
def total_order_sorted(
iterable: Iterable[T],
/,
*,
key: Callable[[T], Any] | None = None,
reverse: bool = False,
) -> list[T]:
"""Sort an iterable in a total order.
This is useful for sorting objects that are not comparable, e.g., dictionaries with different
types of keys.
"""
sequence = list(iterable)
try:
# Sort directly if possible
return sorted(sequence, key=key, reverse=reverse) # type: ignore[type-var,arg-type]
except TypeError:
if key is None:
def key_fn(x: T) -> tuple[str, Any]:
return (f'{x.__class__.__module__}.{x.__class__.__qualname__}', x)
else:
def key_fn(x: T) -> tuple[str, Any]:
y = key(x)
return (f'{y.__class__.__module__}.{y.__class__.__qualname__}', y)
try:
# Add `{obj.__class__.__module__}.{obj.__class__.__qualname__}` to the key order to make
# it sortable between different types (e.g., `int` vs. `str`)
return sorted(sequence, key=key_fn, reverse=reverse)
except TypeError: # cannot sort the keys (e.g., user-defined types)
return sequence # fallback to original order
@overload
def safe_zip(
iter1: Iterable[T],
/,
) -> zip[tuple[T]]: ...
@overload
def safe_zip(
iter1: Iterable[T],
iter2: Iterable[S],
/,
) -> zip[tuple[T, S]]: ...
@overload
def safe_zip(
iter1: Iterable[T],
iter2: Iterable[S],
iter3: Iterable[U],
/,
) -> zip[tuple[T, S, U]]: ...
@overload
def safe_zip(
iter1: Iterable[Any],
iter2: Iterable[Any],
iter3: Iterable[Any],
iter4: Iterable[Any],
/,
*iters: Iterable[Any],
) -> zip[tuple[Any, ...]]: ...
def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]:
"""Strict zip that requires all arguments to be the same length."""
seqs = [arg if isinstance(arg, Sequence) else list(arg) for arg in args]
if len(set(map(len, seqs))) > 1:
raise ValueError(f'length mismatch: {list(map(len, seqs))}')
return zip(*seqs)
def unzip2(xys: Iterable[tuple[T, S]], /) -> tuple[tuple[T, ...], tuple[S, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-2 output.
# For example, for empty dict: tuple(zip(*{}.items())) -> ()
xs = []
ys = []
for x, y in xys:
xs.append(x)
ys.append(y)
return tuple(xs), tuple(ys)
@@ -0,0 +1,60 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""
# pylint: disable=invalid-name
__version__ = '0.18.0'
__license__ = 'Apache-2.0'
__author__ = 'OpTree Contributors'
__release__ = True
if not __release__:
import os
import subprocess
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
try:
prefix, sep, suffix = (
subprocess.check_output( # noqa: S603
[ # noqa: S607
'git',
f'--git-dir={os.path.join(root_dir, ".git")}',
'describe',
'--abbrev=7',
],
cwd=root_dir,
stderr=subprocess.DEVNULL,
text=True,
encoding='utf-8',
)
.strip()
.lstrip('v')
.replace('-', '.dev', 1)
.replace('-', '+', 1)
.partition('.dev')
)
if sep:
version_prefix, dot, version_tail = prefix.rpartition('.')
prefix = f'{version_prefix}{dot}{int(version_tail) + 1}'
__version__ = f'{prefix}{sep}{suffix}'
del version_prefix, dot, version_tail
else:
__version__ = prefix
del prefix, sep, suffix
except (OSError, subprocess.CalledProcessError):
pass
del os, subprocess, root_dir