Permalink
Newer
100644
1737 lines (1507 sloc)
59.3 KB
1
"""
2
ast
3
~~~
4
5
The `ast` module helps Python applications to process trees of the Python
6
abstract syntax grammar. The abstract syntax itself might change with
7
each Python release; this module helps to find out programmatically what
8
the current grammar looks like and allows modifications of it.
9
10
An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11
a flag to the `compile()` builtin function or by using the `parse()`
12
function from this module. The result will be a tree of objects whose
13
classes all inherit from `ast.AST`.
14
15
A modified abstract syntax tree can be compiled into a Python code object
16
using the built-in `compile()` function.
17
18
Additionally various helper functions are provided that make working with
19
the trees simpler. The main intention of the helper functions and this
20
module in general is to provide an easy to use interface for libraries
21
that work tightly with the python syntax (template engines for example).
22
23
24
:copyright: Copyright 2008 by Armin Ronacher.
25
:license: Python License.
26
"""
28
import sys
from _ast import *
from contextlib import contextmanager, nullcontext
from enum import IntEnum, auto, _simple_enum
33
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* may be None (forwarded to compile() as -1), a
    ``(major, minor)`` 2-tuple, or an int giving the minor version for 3.x.
    A tuple whose major component is not 3 raises ValueError.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if feature_version is None:
        feature_version = -1
    elif isinstance(feature_version, tuple):
        major, minor = feature_version  # Should be a 2-tuple.
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        feature_version = minor
    # Else it should be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
53
54
55
def literal_eval(node_or_string):
    """
    Evaluate an expression node or a string containing only a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.

    Raises ValueError for any node that is not one of those structures.

    Caution: A complex expression can overflow the C stack and cause a crash.
    """
    if isinstance(node_or_string, str):
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        msg = "malformed node or string"
        if lno := getattr(node, 'lineno', None):
            msg += f' on line {lno}'
        raise ValueError(msg + f': {node!r}')

    def _convert_num(node):
        # Exact type check: bool (an int subclass) is rejected here.
        if not isinstance(node, Constant) or type(node.value) not in (int, float, complex):
            _raise_malformed_node(node)
        return node.value

    def _convert_signed_num(node):
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif (isinstance(node, Call) and isinstance(node.func, Name) and
              node.func.id == 'set' and node.args == node.keywords == []):
            # The empty set has no literal form; accept a bare ``set()`` call.
            return set()
        elif isinstance(node, Dict):
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only ``<real> +/- <imaginary>`` is accepted, i.e. a complex literal.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
112
113
114
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.  If indent is a non-negative
    integer or string, then the tree will be pretty-printed with that indent
    level. None (the default) selects the single line representation.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node, level=0):
        # Returns (text, simple); "simple" lets the caller keep short nodes
        # on a single line even when indenting.
        if indent is not None:
            level += 1
            prefix = '\n' + indent * level
            sep = ',\n' + indent * level
        else:
            prefix = ''
            sep = ', '
        if isinstance(node, AST):
            cls = type(node)
            args = []
            allsimple = True
            keywords = annotate_fields
            for name in node._fields:
                try:
                    value = getattr(node, name)
                except AttributeError:
                    # A skipped field makes positional output ambiguous,
                    # so switch to keyword form from here on.
                    keywords = True
                    continue
                if value is None and getattr(cls, name, ...) is None:
                    keywords = True
                    continue
                value, simple = _format(value, level)
                allsimple = allsimple and simple
                if keywords:
                    args.append('%s=%s' % (name, value))
                else:
                    args.append(value)
            if include_attributes and node._attributes:
                for name in node._attributes:
                    try:
                        value = getattr(node, name)
                    except AttributeError:
                        continue
                    if value is None and getattr(cls, name, ...) is None:
                        continue
                    value, simple = _format(value, level)
                    allsimple = allsimple and simple
                    args.append('%s=%s' % (name, value))
            if allsimple and len(args) <= 3:
                return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
            return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
        elif isinstance(node, list):
            if not node:
                return '[]', True
            return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
        return repr(node), True

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent
    return _format(node)[0]
179
180
181
def copy_location(new_node, old_node):
    """
    Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset`
    attributes) from *old_node* to *new_node* if possible, and return *new_node*.
    """
    for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
        if attr in old_node._attributes and attr in new_node._attributes:
            value = getattr(old_node, attr, None)
            # end_lineno and end_col_offset are optional attributes, and they
            # should be copied whether the value is None or not.
            if value is not None or (
                hasattr(old_node, attr) and attr.startswith("end_")
            ):
                setattr(new_node, attr, value)
    return new_node
196
197
198
def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node*.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        # Each attribute is either filled in from the inherited value or,
        # when already present, becomes the value inherited by the children.
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'end_lineno' in node._attributes:
            if getattr(node, 'end_lineno', None) is None:
                node.end_lineno = end_lineno
            else:
                end_lineno = node.end_lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        if 'end_col_offset' in node._attributes:
            if getattr(node, 'end_col_offset', None) is None:
                node.end_col_offset = end_col_offset
            else:
                end_col_offset = node.end_col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)
    _fix(node, 1, 0, 1, 0)
    return node
231
232
233
def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*. This is useful to "move code" to a different
    location in a file.  Returns *node*.
    """
    for child in walk(node):
        # TypeIgnore is a special case where lineno is not an attribute
        # but rather a field of the node itself.
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue

        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if (
            "end_lineno" in child._attributes
            and (end_lineno := getattr(child, "end_lineno", 0)) is not None
        ):
            child.end_lineno = end_lineno + n
    return node
254
255
256
def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.
    """
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            # Fields listed in _fields but not set on the node are skipped.
            pass
266
267
268
def iter_child_nodes(node):
    """
    Yield all direct child nodes of *node*, that is, all fields that are nodes
    and all items of fields that are lists of nodes.
    """
    for name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, AST):
                    yield item
280
281
282
def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.

    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
    that can be uniformly removed from the second line onwards is removed.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    node = node.body[0].value
    # A docstring is a plain str Constant.  The deprecated ``Str`` check is
    # covered by this test, since Str instance checks accept exactly the
    # Constant nodes whose value is a str.
    if isinstance(node, Constant) and isinstance(node.value, str):
        text = node.value
    else:
        return None
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
308
def _splitlines_no_ff(source):
    """Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code.  Only ``\\n``,
    ``\\r`` and ``\\r\\n`` end a line; the terminator is kept on each line.
    """
    lines = []
    size = len(source)
    start = 0
    idx = 0
    while idx < size:
        c = source[idx]
        idx += 1
        # Keep \r\n together
        if c == '\r' and idx < size and source[idx] == '\n':
            idx += 1
        if c in '\r\n':
            # Slice once per line instead of growing a string char by char.
            lines.append(source[start:idx])
            start = idx

    if start < size:
        lines.append(source[start:])
    return lines
331
332
333
def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces."""
    # join() avoids the repeated string concatenation of a += loop.
    return ''.join(c if c in '\f\t' else ' ' for c in source)
342
343
344
def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing, return None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.
    """
    try:
        if node.end_lineno is None or node.end_col_offset is None:
            return None
        lineno = node.lineno - 1
        end_lineno = node.end_lineno - 1
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None

    lines = _splitlines_no_ff(source)
    # Lines are encoded before slicing, so the column offsets are applied to
    # the encoded bytes rather than to characters.
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    if padded:
        padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
    else:
        padding = ''

    first = padding + lines[lineno].encode()[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    lines = lines[lineno+1:end_lineno]

    lines.insert(0, first)
    lines.append(last)
    return ''.join(lines)
379
380
381
def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        # Children are queued before the node is yielded.
        todo.extend(iter_child_nodes(node))
        yield node
393
394
395
class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)

    def visit_Constant(self, node):
        """Route a Constant to a legacy visit_Num/visit_Str/... method if defined.

        Using such a legacy method emits a DeprecationWarning; without one
        the node goes to `generic_visit`.
        """
        value = node.value
        type_name = _const_node_type_names.get(type(value))
        if type_name is None:
            # No exact type match: fall back to the first table entry the
            # value is an instance of.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    type_name = name
                    break
        if type_name is not None:
            method = 'visit_' + type_name
            try:
                visitor = getattr(self, method)
            except AttributeError:
                pass
            else:
                import warnings
                warnings.warn(f"{method} is deprecated; add visit_Constant",
                              DeprecationWarning, 2)
                return visitor(node)
        return self.generic_visit(node)
451
452
453
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor methods to replace or remove the old node.  If the return value of
    the visitor method is ``None``, the node will be removed from its location,
    otherwise it is replaced with the return value.  The return value may be the
    original node in which case no replacement takes place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Constant(value=node.id),
                   ctx=node.ctx
               )

    Keep in mind that if the node you're operating on has child nodes you must
    either transform the child nodes yourself or call the :meth:`generic_visit`
    method for the node first.

    For nodes that were part of a collection of statements (that applies to all
    statement nodes), the visitor may also return a list of nodes rather than
    just a single node.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Visit every child of *node*, applying the visitors' return values.

        List fields are rebuilt in place; single-node fields are replaced, or
        deleted when the visitor returns None.  Returns *node*.
        """
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, AST):
                            # A non-node result is treated as an iterable of
                            # replacement items.
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node
512
# If the ast module is loaded more than once, only add deprecated methods once
513
if not hasattr(Constant, 'n'):
    # The following code is for backward compatibility.
    # It will be removed in future.

    def _getter(self):
        """Deprecated. Use value instead."""
        return self.value

    def _setter(self, value):
        self.value = value

    # Both legacy names are plain aliases of ``value``.
    Constant.n = property(_getter, _setter)
    Constant.s = property(_getter, _setter)
526
527
class _ABC(type):
    """Metaclass for the deprecated Constant subclasses (Num, Str, ...)."""

    def __init__(cls, *args):
        cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead"""

    def __instancecheck__(cls, inst):
        """Decide membership from the type of the Constant's value."""
        if not isinstance(inst, Constant):
            return False
        if cls in _const_types:
            try:
                value = inst.value
            except AttributeError:
                return False
            else:
                # _const_types_not lists value types excluded for a class.
                return (
                    isinstance(value, _const_types[cls]) and
                    not isinstance(value, _const_types_not.get(cls, ()))
                )
        return type.__instancecheck__(cls, inst)
546
547
def _new(cls, *args, **kwargs):
    """Shared ``__new__`` for the deprecated Constant subclasses.

    Raises TypeError when a field is supplied both positionally and by
    keyword.  Classes listed in ``_const_types`` produce a plain Constant;
    any other class is created through ``Constant.__new__``.
    """
    for key in kwargs:
        if key not in cls._fields:
            # arbitrary keyword arguments are accepted
            continue
        pos = cls._fields.index(key)
        if pos < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)
558
559
class Num(Constant, metaclass=_ABC):
    # The docstring is assigned by the _ABC metaclass.
    _fields = ('n',)
    __new__ = _new
562
563
class Str(Constant, metaclass=_ABC):
    # The docstring is assigned by the _ABC metaclass.
    _fields = ('s',)
    __new__ = _new
566
567
class Bytes(Constant, metaclass=_ABC):
    # The docstring is assigned by the _ABC metaclass.
    _fields = ('s',)
    __new__ = _new
570
571
class NameConstant(Constant, metaclass=_ABC):
    # The docstring is assigned by the _ABC metaclass.
    __new__ = _new
573
574
class Ellipsis(Constant, metaclass=_ABC):
    # The docstring is assigned by the _ABC metaclass.
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Ellipsis itself yields a Constant whose value is ``...``.
        if cls is Ellipsis:
            return Constant(..., *args, **kwargs)
        return Constant.__new__(cls, *args, **kwargs)
581
582
# Value types accepted by each deprecated Constant subclass.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types rejected even though they match _const_types
# (bool is a subclass of int).
_const_types_not = {
    Num: (bool,),
}
593
# Maps a constant's Python type to the name of the deprecated node class.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
604
class slice(AST):
    """Deprecated AST node class."""
606
607
class Index(slice):
    """Deprecated AST node class. Use the index value directly instead."""
    def __new__(cls, value, **kwargs):
        # Hands back the wrapped value itself; extra keywords are ignored.
        return value
611
612
class ExtSlice(slice):
    """Deprecated AST node class. Use ast.Tuple instead."""
    def __new__(cls, dims=(), **kwargs):
        # Produces a Load-context Tuple holding the given dimensions.
        return Tuple(list(dims), Load(), **kwargs)
616
617
# If the ast module is loaded more than once, only add deprecated methods once
618
if not hasattr(Tuple, 'dims'):
    # The following code is for backward compatibility.
    # It will be removed in future.

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        self.elts = value

    # ``dims`` is a plain alias of ``elts``.
    Tuple.dims = property(_dims_getter, _dims_setter)
631
class Suite(mod):
    """Deprecated AST node class.  Unused in Python 3."""
633
634
class AugLoad(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""
636
637
class AugStore(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""
639
640
class Param(expr_context):
    """Deprecated AST node class.  Unused in Python 3."""
642
644
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR.
# The exponent is one above the largest decimal exponent a float can hold.
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
647
648
@_simple_enum(IntEnum)
649
class _Precedence:
650
"""Precedence table that originated from python grammar."""
651
652
NAMED_EXPR = auto() # <target> := <expr1>
653
TUPLE = auto() # <expr1>, <expr2>
654
YIELD = auto() # 'yield', 'yield from'
655
TEST = auto() # 'if'-'else', 'lambda'
656
OR = auto() # 'or'
657
AND = auto() # 'and'
658
NOT = auto() # 'not'
659
CMP = auto() # '<', '>', '==', '>=', '<=', '!=',
660
# 'in', 'not in', 'is', 'is not'
661
EXPR = auto()
662
BOR = EXPR # '|'
663
BXOR = auto() # '^'
664
BAND = auto() # '&'
665
SHIFT = auto() # '<<', '>>'
666
ARITH = auto() # '+', '-'
667
TERM = auto() # '*', '@', '/', '%', '//'
668
FACTOR = auto() # unary '+', '-', '~'
669
POWER = auto() # '**'
670
AWAIT = auto() # 'await'
671
ATOM = auto()
672
673
def next(self):
674
try:
675
return self.__class__(self + 1)
676
except ValueError:
677
return self
678
679
680
# Quote styles available when writing string literals.
_SINGLE_QUOTES = ("'", '"')
_MULTI_QUOTES = ('"""', "'''")
_ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES)
683
684
class _Unparser(NodeVisitor):
685
"""Methods in this class recursively traverse an AST and
686
output source code for the abstract syntax; original formatting
687
is disregarded."""
688
689
def __init__(self, *, _avoid_backslashes=False):
    """Set up an empty output buffer and per-run bookkeeping."""
    self._source = []
    self._precedences = {}
    self._type_ignores = {}
    self._indent = 0
    self._avoid_backslashes = _avoid_backslashes
    # visit_Try/visit_TryStar read this flag before assigning it.
    self._in_try_star = False
696
697
def interleave(self, inter, f, seq):
    """Call f on each item in seq, calling inter() in between."""
    seq = iter(seq)
    try:
        f(next(seq))
    except StopIteration:
        pass
    else:
        for x in seq:
            inter()
            f(x)
708
709
def items_view(self, traverser, items):
    """Traverse and separate the given *items* with a comma and append it to
    the buffer. If *items* is a single item sequence, a trailing comma
    will be added."""
    if len(items) == 1:
        traverser(items[0])
        self.write(",")
    else:
        self.interleave(lambda: self.write(", "), traverser, items)
718
719
def maybe_newline(self):
    """Adds a newline if it isn't the start of generated source"""
    if self._source:
        self.write("\n")
723
724
def fill(self, text=""):
    """Indent a piece of text and append it, according to the current
    indentation level"""
    self.maybe_newline()
    # One indentation level is four spaces.
    self.write("    " * self._indent + text)
730
def write(self, *text):
    """Add new source parts"""
    self._source.extend(text)
734
@contextmanager
def buffered(self, buffer = None):
    """Temporarily redirect output into *buffer* (a new list by default).

    Yields the buffer; the previous output list is put back on exit, also
    when the with-body raises.
    """
    if buffer is None:
        buffer = []

    original_source = self._source
    self._source = buffer
    try:
        yield buffer
    finally:
        # Restore even on error so later writes do not land in the
        # temporary buffer.
        self._source = original_source
744
@contextmanager
def block(self, *, extra = None):
    """A context manager for preparing the source for blocks. It adds
    the character':', increases the indentation on enter and decreases
    the indentation on exit. If *extra* is given, it will be directly
    appended after the colon character.

    The indentation is decreased again even if the with-body raises.
    """
    self.write(":")
    if extra:
        self.write(extra)
    self._indent += 1
    try:
        yield
    finally:
        self._indent -= 1
758
@contextmanager
def delimit(self, start, end):
    """A context manager for preparing the source for expressions. It adds
    *start* to the buffer and enters, after exit it adds *end*."""

    self.write(start)
    yield
    # *end* is written only when the body completes without raising.
    self.write(end)
766
767
def delimit_if(self, start, end, condition):
    """Return ``self.delimit(start, end)`` if *condition* is true, else a
    no-op context manager."""
    if condition:
        return self.delimit(start, end)
    else:
        return nullcontext()
772
773
def require_parens(self, precedence, node):
    """Shortcut to adding precedence related parens"""
    # Parentheses are added when the precedence recorded for *node* is
    # greater than *precedence*.
    return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
776
777
def get_precedence(self, node):
    """Return the precedence recorded for *node*, or _Precedence.TEST."""
    return self._precedences.get(node, _Precedence.TEST)
779
780
def set_precedence(self, precedence, *nodes):
    """Record *precedence* for each of *nodes*."""
    for node in nodes:
        self._precedences[node] = precedence
783
784
def get_raw_docstring(self, node):
    """If a docstring node is found in the body of the *node* parameter,
    return that docstring node, None otherwise.

    Logic mirrored from ``_PyAST_GetDocString``."""
    if not isinstance(
        node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
    ) or len(node.body) < 1:
        return None
    node = node.body[0]
    if not isinstance(node, Expr):
        return None
    node = node.value
    if isinstance(node, Constant) and isinstance(node.value, str):
        return node
    return None
799
800
def get_type_comment(self, node):
    """Return the `` # type: ...`` suffix for *node*, or None if it has none.

    An entry recorded for the node's line in ``self._type_ignores`` takes
    priority over the node's own ``type_comment``.
    """
    comment = self._type_ignores.get(node.lineno) or node.type_comment
    if comment is not None:
        return f" # type: {comment}"
    return None
804
805
def traverse(self, node):
    """Emit source for *node*; a list is traversed item by item."""
    if isinstance(node, list):
        for item in node:
            self.traverse(item)
    else:
        # Goes through the parent class's visit(), not this class's
        # visit(), which would reset the output buffer.
        super().visit(node)
811
812
# Note: as visit() resets the output text, do NOT rely on
# NodeVisitor.generic_visit to handle any nodes (as it calls back in to
# the subclass visit() method, which resets self._source to an empty list)
def visit(self, node):
    """Outputs a source code string that, if converted back to an ast
    (using ast.parse) will generate an AST equivalent to *node*"""
    self._source = []
    self.traverse(node)
    return "".join(self._source)
821
822
def _write_docstring_and_traverse_body(self, node):
    """Emit the body of *node*, writing a leading docstring specially."""
    if (docstring := self.get_raw_docstring(node)):
        self._write_docstring(docstring)
        # The docstring statement is body[0]; traverse only what follows.
        self.traverse(node.body[1:])
    else:
        self.traverse(node.body)
828
829
def visit_Module(self, node):
    """Emit a module body, exposing its type-ignore comments by line."""
    self._type_ignores = {
        ignore.lineno: f"ignore{ignore.tag}"
        for ignore in node.type_ignores
    }
    self._write_docstring_and_traverse_body(node)
    self._type_ignores.clear()
837
def visit_FunctionType(self, node):
    """Emit ``(argtypes, ...) -> returns``."""
    with self.delimit("(", ")"):
        self.interleave(
            lambda: self.write(", "), self.traverse, node.argtypes
        )

    self.write(" -> ")
    self.traverse(node.returns)
845
846
def visit_Expr(self, node):
    """Emit an expression statement on its own line."""
    self.fill()
    self.set_precedence(_Precedence.YIELD, node.value)
    self.traverse(node.value)
850
851
def visit_NamedExpr(self, node):
    """Emit ``target := value``, parenthesised when precedence requires."""
    with self.require_parens(_Precedence.NAMED_EXPR, node):
        self.set_precedence(_Precedence.ATOM, node.target, node.value)
        self.traverse(node.target)
        self.write(" := ")
        self.traverse(node.value)
857
858
def visit_Import(self, node):
    """Emit ``import a, b, ...``."""
    self.fill("import ")
    self.interleave(lambda: self.write(", "), self.traverse, node.names)
861
862
def visit_ImportFrom(self, node):
    """Emit ``from <dots><module> import a, b, ...``."""
    self.fill("from ")
    # One dot per relative-import level; a missing level counts as zero.
    self.write("." * (node.level or 0))
    if node.module:
        self.write(node.module)
    self.write(" import ")
    self.interleave(lambda: self.write(", "), self.traverse, node.names)
869
870
def visit_Assign(self, node):
    """Emit ``t1 = t2 = ... = value`` plus any type comment."""
    self.fill()
    for target in node.targets:
        self.set_precedence(_Precedence.TUPLE, target)
        self.traverse(target)
        self.write(" = ")
    self.traverse(node.value)
    if type_comment := self.get_type_comment(node):
        self.write(type_comment)
879
880
def visit_AugAssign(self, node):
    """Emit ``target <op>= value``."""
    self.fill()
    self.traverse(node.target)
    # self.binop is indexed by the operator node's class name.
    self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
    self.traverse(node.value)
885
886
def visit_AnnAssign(self, node):
    """Emit ``target: annotation`` with an optional `` = value``."""
    self.fill()
    # A Name target of a non-simple annotated assignment is parenthesised.
    with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
        self.traverse(node.target)
    self.write(": ")
    self.traverse(node.annotation)
    if node.value:
        self.write(" = ")
        self.traverse(node.value)
895
896
def visit_Return(self, node):
    """Emit ``return`` with its value, if any."""
    self.fill("return")
    if node.value:
        self.write(" ")
        self.traverse(node.value)
901
902
def visit_Pass(self, node):
    """Emit ``pass``."""
    self.fill("pass")
904
905
def visit_Break(self, node):
    """Emit ``break``."""
    self.fill("break")
907
908
def visit_Continue(self, node):
    """Emit ``continue``."""
    self.fill("continue")
910
911
def visit_Delete(self, node):
    """Emit ``del a, b, ...``."""
    self.fill("del ")
    self.interleave(lambda: self.write(", "), self.traverse, node.targets)
914
915
def visit_Assert(self, node):
    """Emit ``assert test`` with an optional ``, msg``."""
    self.fill("assert ")
    self.traverse(node.test)
    if node.msg:
        self.write(", ")
        self.traverse(node.msg)
921
922
def visit_Global(self, node):
    """Emit ``global a, b, ...``."""
    self.fill("global ")
    # The names are written directly rather than traversed.
    self.interleave(lambda: self.write(", "), self.write, node.names)
925
926
def visit_Nonlocal(self, node):
    """Emit ``nonlocal a, b, ...``."""
    self.fill("nonlocal ")
    # The names are written directly rather than traversed.
    self.interleave(lambda: self.write(", "), self.write, node.names)
929
930
def visit_Await(self, node):
    """Emit ``await value``, parenthesised when precedence requires."""
    with self.require_parens(_Precedence.AWAIT, node):
        self.write("await")
        if node.value:
            self.write(" ")
            self.set_precedence(_Precedence.ATOM, node.value)
            self.traverse(node.value)
937
938
def visit_Yield(self, node):
    """Emit ``yield`` with its value, if any, parenthesised when required."""
    with self.require_parens(_Precedence.YIELD, node):
        self.write("yield")
        if node.value:
            self.write(" ")
            self.set_precedence(_Precedence.ATOM, node.value)
            self.traverse(node.value)
945
946
def visit_YieldFrom(self, node):
    """Emit ``yield from value``.

    Raises ValueError if the node has no value.
    """
    with self.require_parens(_Precedence.YIELD, node):
        self.write("yield from ")
        if not node.value:
            raise ValueError("Node can't be used without a value attribute.")
        self.set_precedence(_Precedence.ATOM, node.value)
        self.traverse(node.value)
953
954
def visit_Raise(self, node):
    """Emit ``raise``, ``raise exc`` or ``raise exc from cause``.

    Raises ValueError if a cause is given without an exception.
    """
    self.fill("raise")
    if not node.exc:
        if node.cause:
            raise ValueError("Node can't use cause without an exception.")
        return
    self.write(" ")
    self.traverse(node.exc)
    if node.cause:
        self.write(" from ")
        self.traverse(node.cause)
965
967
def do_visit_try(self, node):
    """Emit a try statement: body, handlers, then optional else/finally."""
    self.fill("try")
    with self.block():
        self.traverse(node.body)
    for ex in node.handlers:
        self.traverse(ex)
    if node.orelse:
        self.fill("else")
        with self.block():
            self.traverse(node.orelse)
    if node.finalbody:
        self.fill("finally")
        with self.block():
            self.traverse(node.finalbody)
980
981
def visit_Try(self, node):
    """Emit a plain ``try`` statement."""
    # The flag tells visit_ExceptHandler which keyword to write; it is
    # restored afterwards so nested try statements do not leak their state.
    prev_in_try_star = self._in_try_star
    try:
        self._in_try_star = False
        self.do_visit_try(node)
    finally:
        self._in_try_star = prev_in_try_star
988
989
def visit_TryStar(self, node):
    """Emit a ``try`` statement whose handlers are written as ``except*``."""
    # The flag tells visit_ExceptHandler which keyword to write; it is
    # restored afterwards so nested try statements do not leak their state.
    prev_in_try_star = self._in_try_star
    try:
        self._in_try_star = True
        self.do_visit_try(node)
    finally:
        self._in_try_star = prev_in_try_star
996
997
def visit_ExceptHandler(self, node):
    """Emit an ``except``/``except*`` clause and its body."""
    self.fill("except*" if self._in_try_star else "except")
    if node.type:
        self.write(" ")
        self.traverse(node.type)
    if node.name:
        self.write(" as ")
        self.write(node.name)
    with self.block():
        self.traverse(node.body)
1007
1008
def visit_ClassDef(self, node):
    """Emit a class definition: decorators, header, bases/keywords, body."""
    self.maybe_newline()
    for deco in node.decorator_list:
        self.fill("@")
        self.traverse(deco)
    self.fill("class " + node.name)
    # Parentheses are written only when there is something to put in them.
    with self.delimit_if("(", ")", condition = node.bases or node.keywords):
        # Bases come first, then keywords, all comma-separated.
        self.interleave(
            lambda: self.write(", "),
            self.traverse,
            [*node.bases, *node.keywords],
        )

    with self.block():
        self._write_docstring_and_traverse_body(node)
1031
1032
def visit_FunctionDef(self, node):
    """Emit a ``def`` function definition."""
    self._function_helper(node, "def")
1034
1035
def visit_AsyncFunctionDef(self, node):
    """Emit an ``async def`` function definition."""
    self._function_helper(node, "async def")
1038
def _function_helper(self, node, fill_suffix):
    """Emit a function definition introduced by *fill_suffix*
    (``"def"`` or ``"async def"``)."""
    self.maybe_newline()
    for deco in node.decorator_list:
        self.fill("@")
        self.traverse(deco)
    def_str = fill_suffix + " " + node.name
    self.fill(def_str)
    with self.delimit("(", ")"):
        self.traverse(node.args)
    if node.returns:
        self.write(" -> ")
        self.traverse(node.returns)
    # Any type comment is appended right after the colon.
    with self.block(extra=self.get_type_comment(node)):
        self._write_docstring_and_traverse_body(node)
1052
1053
def visit_For(self, node):
    """Emit a ``for`` loop via the shared loop helper."""
    keyword = "for "
    self._for_helper(keyword, node)
1055
1056
def visit_AsyncFor(self, node):
    """Emit an ``async for`` loop via the shared loop helper."""
    keyword = "async for "
    self._for_helper(keyword, node)
1059
def _for_helper(self, fill, node):
    """Emit a ``for``-style loop whose leading keyword text is *fill*.

    The target is traversed at tuple precedence, followed by `` in `` and
    the iterable.  The body block is opened with the node's type comment
    as ``extra``; an ``else`` block is written only when ``node.orelse``
    is non-empty.
    """
    target = node.target
    self.fill(fill)
    self.set_precedence(_Precedence.TUPLE, target)
    self.traverse(target)
    self.write(" in ")
    self.traverse(node.iter)
    with self.block(extra=self.get_type_comment(node)):
        self.traverse(node.body)
    else_body = node.orelse
    if not else_body:
        return
    self.fill("else")
    with self.block():
        self.traverse(else_body)
1071
1072
def visit_If(self, node):
    """Emit an ``if`` statement, flattening nested ``else: if`` into ``elif``.

    An ``orelse`` list holding exactly one ``If`` node is written as an
    ``elif`` clause; any other non-empty ``orelse`` becomes the final
    ``else`` block.
    """
    keyword = "if "
    while True:
        self.fill(keyword)
        self.traverse(node.test)
        with self.block():
            self.traverse(node.body)
        tail = node.orelse
        # Stop collapsing once the else-branch is not a lone nested `if`.
        if not (tail and len(tail) == 1 and isinstance(tail[0], If)):
            break
        node = tail[0]
        keyword = "elif "
    # final else
    if tail:
        self.fill("else")
        with self.block():
            self.traverse(tail)
1089
1090
def visit_While(self, node):
    """Emit a ``while`` loop and, when present, its ``else`` block."""
    self.fill("while ")
    self.traverse(node.test)
    with self.block():
        self.traverse(node.body)
    else_body = node.orelse
    if not else_body:
        return
    self.fill("else")
    with self.block():
        self.traverse(else_body)
1099
1100
def visit_With(self, node):
    """Emit a ``with`` statement with comma-separated items and its body."""
    def write_separator():
        self.write(", ")

    self.fill("with ")
    self.interleave(write_separator, self.traverse, node.items)
    type_comment = self.get_type_comment(node)
    with self.block(extra=type_comment):
        self.traverse(node.body)
1105
1106
def visit_AsyncWith(self, node):
    """Emit an ``async with`` statement with comma-separated items and its body."""
    def write_separator():
        self.write(", ")

    self.fill("async with ")
    self.interleave(write_separator, self.traverse, node.items)
    type_comment = self.get_type_comment(node)
    with self.block(extra=type_comment):
        self.traverse(node.body)
1111
1112
def _str_literal_helper(
    self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
):
    """Helper for writing string literals, minimizing escapes.

    Returns the tuple (string literal to write, possible quote types).
    The first element is the literal's contents without surrounding
    quotes; the second lists the quote styles from *quote_types* that can
    enclose it.
    """
    def escape_char(c):
        # \n and \t are non-printable, but they are only escaped when
        # escape_special_whitespace is True.
        keep_raw = c in "\n\t" and not escape_special_whitespace
        # Backslashes and other non-printable characters are always escaped.
        if not keep_raw and (c == "\\" or not c.isprintable()):
            return c.encode("unicode_escape").decode("ascii")
        return c

    escaped_string = "".join(escape_char(c) for c in string)
    candidates = quote_types
    if "\n" in escaped_string:
        # A raw newline can only live inside a multi-line quote style.
        candidates = [q for q in candidates if q in _MULTI_QUOTES]
    candidates = [q for q in candidates if q not in escaped_string]
    if not candidates:
        # Nothing fits: fall back to repr() of the original string, but
        # try to keep a quote from quote_types (e.g. so that docstrings
        # still get triple quotes).
        literal = repr(string)
        opener = literal[0]
        quote = next((q for q in quote_types if opener in q), opener)
        return literal[1:-1], [quote]
    if escaped_string:
        last_char = escaped_string[-1]
        # Sort so that we prefer '''"''' over """\""""
        candidates.sort(key=lambda q: q[0] == last_char)
        # If the preferred (triple) quote still collides with the final
        # character, escape that character.
        if candidates[0][0] == last_char:
            assert len(candidates[0]) == 3
            escaped_string = escaped_string[:-1] + "\\" + last_char
    return escaped_string, candidates
1149
1150
def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
    """Write string literal value with a best effort attempt to avoid backslashes."""
    body, usable_quotes = self._str_literal_helper(string, quote_types=quote_types)
    # The helper lists its preferred quote style first.
    quote = usable_quotes[0]
    self.write(f"{quote}{body}{quote}")
1155
1156
def visit_JoinedStr(self, node):
    """Emit an f-string literal for a ``JoinedStr`` node.

    When ``self._avoid_backslashes`` is set, the whole f-string body is
    rendered into a buffer and written through
    ``_write_str_avoiding_backslashes`` (whose result is returned).
    Otherwise each value is rendered separately so that only the constant
    parts may use escaped whitespace, and one quote style is chosen that
    works for every part.
    """
    self.write("f")
    if self._avoid_backslashes:
        with self.buffered() as buffer:
            self._write_fstring_inner(node)
        return self._write_str_avoiding_backslashes("".join(buffer))

    # If we don't need to avoid backslashes globally (i.e., we only need
    # to avoid them inside FormattedValues), it's cosmetically preferred
    # to use escaped whitespace. That is, it's preferred to use backslashes
    # for cases like: f"{x}\n". To accomplish this, we keep track of what
    # in our buffer corresponds to FormattedValues and what corresponds to
    # Constant parts of the f-string, and allow escapes accordingly.
    fstring_parts = []  # (rendered text, is_constant) pairs, in order
    for value in node.values:
        with self.buffered() as buffer:
            self._write_fstring_inner(value)
        fstring_parts.append(
            ("".join(buffer), isinstance(value, Constant))
        )

    new_fstring_parts = []
    quote_types = list(_ALL_QUOTES)
    for value, is_constant in fstring_parts:
        # Each part narrows the quote styles still usable for the whole
        # literal; only constant parts may escape \n and \t.
        value, quote_types = self._str_literal_helper(
            value,
            quote_types=quote_types,
            escape_special_whitespace=is_constant,
        )
        new_fstring_parts.append(value)

    value = "".join(new_fstring_parts)
    quote_type = quote_types[0]
    self.write(f"{quote_type}{value}{quote_type}")
1191
def _write_fstring_inner(self, node):
1192
if isinstance(node, JoinedStr):
1193
# for both the f-string itself, and format_spec
1194
for value in node.values:
1195
self._write_fstring_inner(value)
1196
elif isinstance(node, Constant) and isinstance(node.value, str):
1197
value = node.value.replace("{", "{{").replace("}", "}}")
1198
self.write(value)
1199
elif isinstance(node, FormattedValue):
1200
self.visit_FormattedValue(node)
1201
else:
1202
raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
1204
def visit_FormattedValue(self, node):
    """Emit one ``{...}`` replacement field of an f-string.

    The expression is rendered by a fresh unparser of the same type that
    was created with ``_avoid_backslashes=True``.

    Raises ValueError if the rendered expression still contains a
    backslash.
    """
    with self.delimit("{", "}"):
        inner_unparser = type(self)(_avoid_backslashes=True)
        inner_unparser.set_precedence(_Precedence.TEST.next(), node.value)
        expr = inner_unparser.visit(node.value)
        if "\\" in expr:
            raise ValueError(
                "Unable to avoid backslash in f-string expression part"
            )
        if expr.startswith("{"):
            # Separate pair of opening brackets as "{ {"
            self.write(" ")
        self.write(expr)
        conversion = node.conversion
        if conversion != -1:
            self.write(f"!{chr(conversion)}")
        spec = node.format_spec
        if spec:
            self.write(":")
            self._write_fstring_inner(spec)
1225
1226
def visit_Name(self, node):
    """Emit an identifier exactly as stored in ``node.id``."""
    identifier = node.id
    self.write(identifier)
1228
1229
def _write_docstring(self, node):
    """Emit a docstring constant, restricted to the multi-line quote styles.

    A ``u`` prefix is written first when ``node.kind`` is ``"u"``.
    """
    self.fill()
    has_u_prefix = node.kind == "u"
    if has_u_prefix:
        self.write("u")
    self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
1235
def _write_constant(self, value):
    """Write the source form of a single constant *value*.

    Floats and complex numbers go through ``repr()`` with ``inf``/``nan``
    replaced by expressions built from ``_INFSTR``.  Strings are routed
    through ``_write_str_avoiding_backslashes`` when
    ``self._avoid_backslashes`` is set.  Everything else uses ``repr()``.
    """
    if isinstance(value, (float, complex)):
        # Substitute overflowing decimal literal for AST infinities,
        # and inf - inf for NaNs.
        text = repr(value).replace("inf", _INFSTR)
        text = text.replace("nan", f"({_INFSTR}-{_INFSTR})")
        self.write(text)
        return
    if self._avoid_backslashes and isinstance(value, str):
        self._write_str_avoiding_backslashes(value)
        return
    self.write(repr(value))
1248
1249
def visit_Constant(self, node):
    """Emit a ``Constant`` node.

    Tuples are written element by element inside parentheses, the
    ``Ellipsis`` singleton as ``...``, and every other value through
    ``_write_constant`` (preceded by ``u`` when ``node.kind`` is ``"u"``).
    """
    value = node.value
    if isinstance(value, tuple):
        with self.delimit("(", ")"):
            self.items_view(self._write_constant, value)
        return
    if value is ...:
        self.write("...")
        return
    if node.kind == "u":
        self.write("u")
    self._write_constant(value)
1260
1261
def visit_List(self, node):
    """Emit a list display: ``[`` elements separated by ``, `` ``]``."""
    def write_separator():
        self.write(", ")

    with self.delimit("[", "]"):
        self.interleave(write_separator, self.traverse, node.elts)
1264
1265
def visit_ListComp(self, node):
    """Emit a list comprehension: the element, then each generator clause, in ``[]``."""
    with self.delimit("[", "]"):
        for part in (node.elt, *node.generators):
            self.traverse(part)
1270
1271
def visit_GeneratorExp(self, node):
    """Emit a generator expression: the element, then each generator clause, in ``()``."""
    with self.delimit("(", ")"):
        for part in (node.elt, *node.generators):
            self.traverse(part)
1276
1277
def visit_SetComp(self, node):
    """Emit a set comprehension: the element, then each generator clause, in ``{}``."""
    with self.delimit("{", "}"):
        for part in (node.elt, *node.generators):
            self.traverse(part)
1282
1283
def visit_DictComp(self, node):
    """Emit a dict comprehension: ``key: value`` then each generator clause, in ``{}``."""
    with self.delimit("{", "}"):
        self.traverse(node.key)
        self.write(": ")
        for part in (node.value, *node.generators):
            self.traverse(part)
1290
1291
def visit_comprehension(self, node):
    """Emit one ``for ... in ...`` clause of a comprehension with its ``if`` filters.

    The clause starts with `` async for `` when ``node.is_async`` is
    truthy and `` for `` otherwise.
    """
    self.write(" async for " if node.is_async else " for ")
    target = node.target
    self.set_precedence(_Precedence.TUPLE, target)
    self.traverse(target)
    self.write(" in ")
    # The iterable and every filter share the precedence just above TEST.
    self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
    self.traverse(node.iter)
    for condition in node.ifs:
        self.write(" if ")
        self.traverse(condition)
1304
1305
def visit_IfExp(self, node):
    """Emit a conditional expression ``body if test else orelse``.

    The body and test are set to the precedence just above TEST, while
    the ``else`` operand is set to TEST itself.
    """
    then_value, condition, else_value = node.body, node.test, node.orelse
    with self.require_parens(_Precedence.TEST, node):
        self.set_precedence(_Precedence.TEST.next(), then_value, condition)
        self.traverse(then_value)
        self.write(" if ")
        self.traverse(condition)
        self.write(" else ")
        self.set_precedence(_Precedence.TEST, else_value)
        self.traverse(else_value)
1314
1315
def visit_Set(self, node):
    """Emit a set display; an empty set is written as ``{*()}``."""
    if not node.elts:
        # `{}` would be interpreted as a dictionary literal, and
        # `set` might be shadowed. Thus:
        self.write('{*()}')
        return
    with self.delimit("{", "}"):
        self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1323
1324
def visit_Dict(self, node):
    """Emit a dict display, including ``**mapping`` unpacking entries."""
    def emit_entry(entry):
        key, value = entry
        if key is None:
            # for dictionary unpacking operator in dicts {**{'y': 2}}
            # see PEP 448 for details
            self.write("**")
            self.set_precedence(_Precedence.EXPR, value)
            self.traverse(value)
            return
        self.traverse(key)
        self.write(": ")
        self.traverse(value)

    with self.delimit("{", "}"):
        entries = zip(node.keys, node.values)
        self.interleave(lambda: self.write(", "), emit_entry, entries)
1345
1346
def visit_Tuple(self, node):
    """Emit a tuple display.

    Parentheses are written for the empty tuple and whenever the node's
    recorded precedence is higher than ``_Precedence.TUPLE``.
    """
    needs_parens = (
        len(node.elts) == 0
        or self.get_precedence(node) > _Precedence.TUPLE
    )
    with self.delimit_if("(", ")", needs_parens):
        self.items_view(self.traverse, node.elts)
1353
1354
# Unary operator node class name -> source token.
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
# Source token -> precedence of the unary operator.
unop_precedence = {
    "not": _Precedence.NOT,
    "~": _Precedence.FACTOR,
    "+": _Precedence.FACTOR,
    "-": _Precedence.FACTOR,
}
1361
1362
def visit_UnaryOp(self, node):
    """Emit a unary operation, parenthesised where precedence requires it."""
    symbol = self.unop[node.op.__class__.__name__]
    precedence = self.unop_precedence[symbol]
    operand = node.operand
    with self.require_parens(precedence, node):
        self.write(symbol)
        # factor prefixes (+, -, ~) shouldn't be separated
        # from the value they belong, (e.g: +1 instead of + 1)
        if precedence is not _Precedence.FACTOR:
            self.write(" ")
        self.set_precedence(precedence, operand)
        self.traverse(operand)
1373
1374
# Binary operator node class name -> source token.
binop = {
    "Add": "+", "Sub": "-",
    "Mult": "*", "MatMult": "@", "Div": "/", "Mod": "%",
    "LShift": "<<", "RShift": ">>",
    "BitOr": "|", "BitXor": "^", "BitAnd": "&",
    "FloorDiv": "//",
    "Pow": "**",
}

# Source token -> precedence of the binary operator.
binop_precedence = {
    "+": _Precedence.ARITH, "-": _Precedence.ARITH,
    "*": _Precedence.TERM, "@": _Precedence.TERM,
    "/": _Precedence.TERM, "%": _Precedence.TERM,
    "<<": _Precedence.SHIFT, ">>": _Precedence.SHIFT,
    "|": _Precedence.BOR,
    "^": _Precedence.BXOR,
    "&": _Precedence.BAND,
    "//": _Precedence.TERM,
    "**": _Precedence.POWER,
}

# Operators treated as right-associative by visit_BinOp.
binop_rassoc = frozenset(("**",))
1407
def visit_BinOp(self, node):
    """Emit a binary operation with associativity-aware operand precedence.

    The operand on the operator's associative side keeps the operator's
    own precedence; the other operand gets the next precedence level.
    """
    symbol = self.binop[node.op.__class__.__name__]
    precedence = self.binop_precedence[symbol]
    with self.require_parens(precedence, node):
        bumped = precedence.next()
        if symbol in self.binop_rassoc:
            left_precedence, right_precedence = bumped, precedence
        else:
            left_precedence, right_precedence = precedence, bumped

        self.set_precedence(left_precedence, node.left)
        self.traverse(node.left)
        self.write(f" {symbol} ")
        self.set_precedence(right_precedence, node.right)
        self.traverse(node.right)
1423
1424
# Comparison operator node class name -> source token.
cmpops = {
    "Eq": "==", "NotEq": "!=",
    "Lt": "<", "LtE": "<=",
    "Gt": ">", "GtE": ">=",
    "Is": "is", "IsNot": "is not",
    "In": "in", "NotIn": "not in",
}
1436
1437
def visit_Compare(self, node):
    """Emit a (possibly chained) comparison such as ``a < b <= c``."""
    with self.require_parens(_Precedence.CMP, node):
        # Every operand is set to the precedence just above CMP.
        self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
        self.traverse(node.left)
        for op, operand in zip(node.ops, node.comparators):
            symbol = self.cmpops[op.__class__.__name__]
            self.write(" " + symbol + " ")
            self.traverse(operand)
1445
# Boolean operator node class name -> source token.
boolops = {
    "And": "and",
    "Or": "or",
}
# Source token -> precedence of the boolean operator.
boolop_precedence = {
    "and": _Precedence.AND,
    "or": _Precedence.OR,
}
1447
1448
def visit_BoolOp(self, node):
    """Emit an ``and``/``or`` chain.

    Parenthesisation of the whole expression uses the operator's own
    precedence.  Each successive operand is traversed one precedence
    level higher than the operand before it, starting one level above
    the operator.
    """
    operator = self.boolops[node.op.__class__.__name__]
    base_precedence = self.boolop_precedence[operator]
    separator = f" {operator} "
    level = base_precedence

    def traverse_one_level_up(value):
        nonlocal level
        level = level.next()
        self.set_precedence(level, value)
        self.traverse(value)

    with self.require_parens(base_precedence, node):
        self.interleave(
            lambda: self.write(separator), traverse_one_level_up, node.values
        )
1461
1462
def visit_Attribute(self, node):
    """Emit an attribute access ``value.attr``."""
    owner = node.value
    self.set_precedence(_Precedence.ATOM, owner)
    self.traverse(owner)
    # Special case: 3.__abs__() is a syntax error, so if node.value
    # is an integer literal then we need to either parenthesize
    # it or add an extra space to get 3 .__abs__().
    is_int_literal = isinstance(owner, Constant) and isinstance(owner.value, int)
    if is_int_literal:
        self.write(" ")
    self.write(".")
    self.write(node.attr)
1472
1473
def visit_Call(self, node):
    """Emit a call: the callee, then positional and keyword arguments in ``()``."""
    self.set_precedence(_Precedence.ATOM, node.func)
    self.traverse(node.func)
    with self.delimit("(", ")"):
        # Positional arguments first, then keywords, separated by ", ".
        self.interleave(
            lambda: self.write(", "),
            self.traverse,
            [*node.args, *node.keywords],
        )
1490
1491
def visit_Subscript(self, node):
    """Emit a subscription ``value[slice]``.

    A non-empty tuple slice is written without its own parentheses; any
    other slice is traversed as a whole.
    """
    self.set_precedence(_Precedence.ATOM, node.value)
    self.traverse(node.value)
    with self.delimit("[", "]"):
        index = node.slice
        if isinstance(index, Tuple) and index.elts:
            # parentheses can be omitted if the tuple isn't empty
            self.items_view(self.traverse, index.elts)
        else:
            self.traverse(index)
1506
1507
def visit_Starred(self, node):
    """Emit a starred expression ``*value``."""
    operand = node.value
    self.write("*")
    self.set_precedence(_Precedence.EXPR, operand)
    self.traverse(operand)
1511
1512
def visit_Ellipsis(self, node):
    """Emit the ``...`` literal; *node* itself is not inspected."""
    literal = "..."
    self.write(literal)
1514
1515
def visit_Slice(self, node):
    """Emit a slice ``lower:upper[:step]``, omitting absent parts.

    The first colon is always written; the second colon appears only
    together with a step.
    """
    lower, upper, step = node.lower, node.upper, node.step
    if lower:
        self.traverse(lower)
    self.write(":")
    if upper:
        self.traverse(upper)
    if step:
        self.write(":")
        self.traverse(step)
1524
1525
def visit_Match(self, node):
    """Emit a ``match`` statement: the subject, then each case clause in a block."""
    self.fill("match ")
    self.traverse(node.subject)
    with self.block():
        for clause in node.cases:
            self.traverse(clause)
1531
1532
def visit_arg(self, node):
    """Emit a single parameter name with its optional ``: annotation``."""
    self.write(node.arg)
    annotation = node.annotation
    if annotation:
        self.write(": ")
        self.traverse(annotation)
1537
1538
def visit_arguments(self, node):
    """Emit a full parameter list.

    Order of output: positional-only and regular parameters (with a
    ``/`` marker after the last positional-only one), ``*vararg`` or a
    bare ``*``, keyword-only parameters, and finally ``**kwarg``.
    """
    wrote_any = False

    def separate():
        # Write ", " before every entry except the very first one.
        nonlocal wrote_any
        if wrote_any:
            self.write(", ")
        wrote_any = True

    # normal arguments
    positional = node.posonlyargs + node.args
    # Defaults align with the *last* parameters; pad the front with None.
    padding = [None] * (len(positional) - len(node.defaults))
    posonly_count = len(node.posonlyargs)
    pairs = zip(positional, padding + node.defaults)
    for position, (param, default) in enumerate(pairs, 1):
        separate()
        self.traverse(param)
        if default:
            self.write("=")
            self.traverse(default)
        if position == posonly_count:
            self.write(", /")

    # varargs, or bare '*' if no varargs but keyword-only arguments present
    vararg = node.vararg
    if vararg or node.kwonlyargs:
        separate()
        self.write("*")
        if vararg:
            self.write(vararg.arg)
            if vararg.annotation:
                self.write(": ")
                self.traverse(vararg.annotation)

    # keyword-only arguments
    if node.kwonlyargs:
        for param, default in zip(node.kwonlyargs, node.kw_defaults):
            self.write(", ")
            self.traverse(param)
            if default:
                self.write("=")
                self.traverse(default)

    # kwargs
    kwarg = node.kwarg
    if kwarg:
        separate()
        self.write("**" + kwarg.arg)
        if kwarg.annotation:
            self.write(": ")
            self.traverse(kwarg.annotation)
1588
1589
def visit_keyword(self, node):
    """Emit a keyword argument ``name=value``, or ``**value`` when it has no name."""
    name = node.arg
    if name is None:
        self.write("**")
    else:
        self.write(name)
        self.write("=")
    self.traverse(node.value)
1596
1597
def visit_Lambda(self, node):
    """Emit a ``lambda`` expression.

    The parameters are rendered into a buffer first, so that the space
    after ``lambda`` is written only when the buffer is non-empty.
    """
    with self.require_parens(_Precedence.TEST, node):
        self.write("lambda")
        with self.buffered() as rendered_args:
            self.traverse(node.args)
        if rendered_args:
            self.write(" ", *rendered_args)
        self.write(": ")
        body = node.body
        self.set_precedence(_Precedence.TEST, body)
        self.traverse(body)
1607
1608
def visit_alias(self, node):
    """Emit an import alias: ``name`` or ``name as asname``."""
    self.write(node.name)
    local_name = node.asname
    if local_name:
        self.write(" as " + local_name)
1612
1613
def visit_withitem(self, node):
    """Emit one ``with`` item: the context expression and its optional ``as`` target."""
    self.traverse(node.context_expr)
    target = node.optional_vars
    if target:
        self.write(" as ")
        self.traverse(target)
1618
1619
def visit_match_case(self, node):
    """Emit a ``case`` clause: the pattern, an optional ``if`` guard, then the body."""
    self.fill("case ")
    self.traverse(node.pattern)
    guard = node.guard
    if guard:
        self.write(" if ")
        self.traverse(guard)
    with self.block():
        self.traverse(node.body)
1627
1628
def visit_MatchValue(self, node):
    """Emit a value pattern by traversing the wrapped expression."""
    expression = node.value
    self.traverse(expression)
1630
1631
def visit_MatchSingleton(self, node):
    """Emit a singleton pattern by writing its value as a constant."""
    singleton = node.value
    self._write_constant(singleton)
1633
1634
def visit_MatchSequence(self, node):
    """Emit a sequence pattern as ``[p1, p2, ...]``."""
    def write_separator():
        self.write(", ")

    with self.delimit("[", "]"):
        self.interleave(write_separator, self.traverse, node.patterns)
1639
1640
def visit_MatchStar(self, node):
    """Emit a star pattern ``*name``; a missing name is written as ``*_``."""
    name = "_" if node.name is None else node.name
    self.write(f"*{name}")
1645
1646
def visit_MatchMapping(self, node):
1647
def write_key_pattern_pair(pair):
1648
k, p = pair
1649
self.traverse(k)
1650
self.write(": ")
1651
self.traverse(p)
1652
1653
with self.delimit("{", "}"):
1654
keys = node.keys
1655
self.interleave(
1656
lambda: self.write(", "),
1657
write_key_pattern_pair,
1658
zip(keys, node.patterns, strict=True),
1659
)
1660
rest = node.rest
1661
if rest is not None:
1662
if keys:
1663
self.write(", ")
1664
self.write(f"**{rest}")
1665
1666
def visit_MatchClass(self, node):
    """Emit a class pattern ``Cls(p1, ..., attr=pattern, ...)``.

    Keyword attribute names and their patterns are paired strictly, so
    lists of different lengths raise ``ValueError`` from ``zip``.
    """
    def write_separator():
        self.write(", ")

    def emit_keyword_pattern(pair):
        attr, pattern = pair
        self.write(f"{attr}=")
        self.traverse(pattern)

    self.set_precedence(_Precedence.ATOM, node.cls)
    self.traverse(node.cls)
    with self.delimit("(", ")"):
        positional = node.patterns
        self.interleave(write_separator, self.traverse, positional)
        attrs = node.kwd_attrs
        if not attrs:
            return
        # Separate keyword patterns from any positional ones before them.
        if positional:
            self.write(", ")
        self.interleave(
            write_separator,
            emit_keyword_pattern,
            zip(attrs, node.kwd_patterns, strict=True),
        )
1688
1689
def visit_MatchAs(self, node):
    """Emit a capture/wildcard pattern.

    No name is written as ``_``, a name without a sub-pattern is written
    as the bare name, and otherwise the output is ``pattern as name``.
    """
    name = node.name
    pattern = node.pattern
    if name is None:
        self.write("_")
        return
    if pattern is None:
        self.write(name)
        return
    with self.require_parens(_Precedence.TEST, node):
        self.set_precedence(_Precedence.BOR, pattern)
        self.traverse(pattern)
        self.write(f" as {name}")
1701
1702
def visit_MatchOr(self, node):
    """Emit an or-pattern: the alternatives joined with `` | ``."""
    alternatives = node.patterns
    with self.require_parens(_Precedence.BOR, node):
        # Each alternative is set to the precedence just above BOR.
        self.set_precedence(_Precedence.BOR.next(), *alternatives)
        self.interleave(lambda: self.write(" | "), self.traverse, alternatives)
1706
1707
def unparse(ast_obj):
    """Return the result of visiting *ast_obj* with a fresh ``_Unparser``."""
    return _Unparser().visit(ast_obj)
1710
1711
1712
def main():
    """Command-line entry point: parse a file (or stdin) and print its AST dump."""
    import argparse

    cli = argparse.ArgumentParser(prog='python -m ast')
    cli.add_argument(
        'infile',
        type=argparse.FileType(mode='rb'),
        nargs='?',
        default='-',
        help='the file to parse; defaults to stdin',
    )
    cli.add_argument(
        '-m', '--mode',
        default='exec',
        choices=('exec', 'single', 'eval', 'func_type'),
        help='specify what kind of code must be parsed',
    )
    # Stored as `no_type_comments`: True by default, False when the flag is given.
    cli.add_argument(
        '--no-type-comments',
        default=True,
        action='store_false',
        help="don't add information about type comments",
    )
    cli.add_argument(
        '-a', '--include-attributes',
        action='store_true',
        help='include attributes such as line numbers and '
             'column offsets',
    )
    cli.add_argument(
        '-i', '--indent',
        type=int,
        default=3,
        help='indentation of nodes (number of spaces)',
    )
    options = cli.parse_args()

    with options.infile as stream:
        source = stream.read()
    tree = parse(
        source,
        options.infile.name,
        options.mode,
        type_comments=options.no_type_comments,
    )
    print(dump(tree, include_attributes=options.include_attributes, indent=options.indent))
1735
1736
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()