use pre-classified fields in visitor generic_visit methods

Both NodeVisitor.generic_visit and NodeTransformer.generic_visit now
iterate _node_fields and _node_list_fields directly instead of
calling iter_child_nodes/iter_fields.  This avoids generator overhead
and isinstance checks on every visited node.
This commit is contained in:
tobymao 2026-03-21 23:36:26 -07:00
parent 6a5f1dfb75
commit 4b6c3d6f4f
No known key found for this signature in database
GPG Key ID: 5B3B30F3BACB48DC

View File

@ -43,8 +43,18 @@ class NodeVisitor:
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Called if no explicit visitor function exists for a node."""
for child_node in node.iter_child_nodes():
self.visit(child_node, *args, **kwargs)
d = node.__dict__
node_type = type(node)
visit = self.visit
for name in node_type._node_fields:
child = d.get(name)
if child is not None:
visit(child, *args, **kwargs)
for name in node_type._node_list_fields:
children = d.get(name)
if children:
for child in children:
visit(child, *args, **kwargs)
class NodeTransformer(NodeVisitor):
@ -59,25 +69,36 @@ class NodeTransformer(NodeVisitor):
"""
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
for field, old_value in node.iter_fields():
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, Node):
value = self.visit(value, *args, **kwargs)
if value is None:
continue
elif not isinstance(value, Node):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, Node):
new_node = self.visit(old_value, *args, **kwargs)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
d = node.__dict__
node_type = type(node)
visit = self.visit
for field in node_type._node_fields:
old_value = d.get(field)
if old_value is None:
continue
new_node = visit(old_value, *args, **kwargs)
if new_node is None:
d.pop(field, None)
elif new_node is not old_value:
d[field] = new_node
for field in node_type._node_list_fields:
old_value = d.get(field)
if not old_value:
continue
new_values: list[t.Any] = []
for value in old_value:
if isinstance(value, Node):
value = visit(value, *args, **kwargs)
if value is None:
continue
elif not isinstance(value, Node):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
return node
def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> list[Node]: