Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,30 @@
# coding=utf-8
"""Pasta enables AST-based transformations on python source code."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pasta.base import annotate
from pasta.base import ast_utils
from pasta.base import codegen
def parse(src):
t = ast_utils.parse(src)
annotator = annotate.AstAnnotator(src)
annotator.visit(t)
return t
def dump(tree):
return codegen.to_str(tree)
@@ -0,0 +1,23 @@
# coding=utf-8
"""Errors that can occur during augmentation."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class InvalidAstError(Exception):
"""Occurs when the syntax tree does not meet some expected condition."""
@@ -0,0 +1,217 @@
# coding=utf-8
"""Functions for dealing with import statements."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import copy
import logging
from pasta.augment import errors
from pasta.base import ast_utils
from pasta.base import scope
def add_import(tree, name_to_import, asname=None, from_import=True, merge_from_imports=True):
"""Adds an import to the module.
This function will try to ensure not to create duplicate imports. If name_to_import is
already imported, it will return the existing import. This is true even if asname is set
(asname will be ignored, and the existing name will be returned).
If the import would create a name that already exists in the scope given by tree, this
function will "import as", and append "_x" to the asname where x is the smallest positive
integer generating a unique name.
Arguments:
tree: (ast.Module) Module AST to modify.
name_to_import: (string) The absolute name to import.
asname: (string) The alias for the import ("import name_to_import as asname")
from_import: (boolean) If True, import the name using an ImportFrom node.
merge_from_imports: (boolean) If True, merge a newly inserted ImportFrom
node into an existing ImportFrom node, if applicable.
Returns:
The name (as a string) that can be used to reference the imported name. This
can be the fully-qualified name, the basename, or an alias name.
"""
sc = scope.analyze(tree)
# Don't add anything if it's already imported
if name_to_import in sc.external_references:
existing_ref = next((ref for ref in sc.external_references[name_to_import]
if ref.name_ref is not None), None)
if existing_ref:
return existing_ref.name_ref.id
import_node = None
added_name = None
def make_safe_alias_node(alias_name, asname):
# Try to avoid name conflicts
new_alias = ast.alias(name=alias_name, asname=asname)
imported_name = asname or alias_name
counter = 0
while imported_name in sc.names:
counter += 1
imported_name = new_alias.asname = '%s_%d' % (asname or alias_name,
counter)
return new_alias
# Add an ImportFrom node if requested and possible
if from_import and '.' in name_to_import:
from_module, alias_name = name_to_import.rsplit('.', 1)
new_alias = make_safe_alias_node(alias_name, asname)
if merge_from_imports:
# Try to add to an existing ImportFrom from the same module
existing_from_import = next(
(node for node in tree.body if isinstance(node, ast.ImportFrom)
and node.module == from_module and node.level == 0), None)
if existing_from_import:
existing_from_import.names.append(new_alias)
return new_alias.asname or new_alias.name
# Create a new node for this import
import_node = ast.ImportFrom(module=from_module, names=[new_alias], level=0)
# If not already created as an ImportFrom, create a normal Import node
if not import_node:
new_alias = make_safe_alias_node(alias_name=name_to_import, asname=asname)
import_node = ast.Import(
names=[new_alias])
# Insert the node at the top of the module and return the name in scope
tree.body.insert(1 if ast_utils.has_docstring(tree) else 0, import_node)
return new_alias.asname or new_alias.name
def split_import(sc, node, alias_to_remove):
"""Split an import node by moving the given imported alias into a new import.
Arguments:
sc: (scope.Scope) Scope computed on whole tree of the code being modified.
node: (ast.Import|ast.ImportFrom) An import node to split.
alias_to_remove: (ast.alias) The import alias node to remove. This must be a
child of the given `node` argument.
Raises:
errors.InvalidAstError: if `node` is not appropriately contained in the tree
represented by the scope `sc`.
"""
parent = sc.parent(node)
parent_list = None
for a in ('body', 'orelse', 'finalbody'):
if hasattr(parent, a) and node in getattr(parent, a):
parent_list = getattr(parent, a)
break
else:
raise errors.InvalidAstError('Unable to find list containing import %r on '
'parent node %r' % (node, parent))
idx = parent_list.index(node)
new_import = copy.deepcopy(node)
new_import.names = [alias_to_remove]
node.names.remove(alias_to_remove)
parent_list.insert(idx + 1, new_import)
return new_import
def get_unused_import_aliases(tree, sc=None):
"""Get the import aliases that aren't used.
Arguments:
tree: (ast.AST) An ast to find imports in.
sc: A scope.Scope representing tree (generated from scratch if not
provided).
Returns:
A list of ast.alias representing imported aliases that aren't referenced in
the given tree.
"""
if sc is None:
sc = scope.analyze(tree)
unused_aliases = set()
for node in ast.walk(tree):
if isinstance(node, ast.alias):
str_name = node.asname if node.asname is not None else node.name
if str_name in sc.names:
name = sc.names[str_name]
if not name.reads:
unused_aliases.add(node)
else:
# This happens because of https://github.com/google/pasta/issues/32
logging.warning('Imported name %s not found in scope (perhaps it\'s '
'imported dynamically)', str_name)
return unused_aliases
def remove_import_alias_node(sc, node):
"""Remove an alias and if applicable remove their entire import.
Arguments:
sc: (scope.Scope) Scope computed on whole tree of the code being modified.
node: (ast.Import|ast.ImportFrom|ast.alias) The node to remove.
"""
import_node = sc.parent(node)
if len(import_node.names) == 1:
import_parent = sc.parent(import_node)
ast_utils.remove_child(import_parent, import_node)
else:
ast_utils.remove_child(import_node, node)
def remove_duplicates(tree, sc=None):
"""Remove duplicate imports, where it is safe to do so.
This does NOT remove imports that create new aliases
Arguments:
tree: (ast.AST) An ast to modify imports in.
sc: A scope.Scope representing tree (generated from scratch if not
provided).
Returns:
Whether any changes were made.
"""
if sc is None:
sc = scope.analyze(tree)
modified = False
seen_names = set()
for node in tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
for alias in list(node.names):
import_node = sc.parent(alias)
if isinstance(import_node, ast.Import):
full_name = alias.name
elif import_node.module:
full_name = '%s%s.%s' % ('.' * import_node.level,
import_node.module, alias.name)
else:
full_name = '%s%s' % ('.' * import_node.level, alias.name)
full_name += ':' + (alias.asname or alias.name)
if full_name in seen_names:
remove_import_alias_node(sc, alias)
modified = True
else:
seen_names.add(full_name)
return modified
@@ -0,0 +1,428 @@
# coding=utf-8
"""Tests for import_utils."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import traceback
import unittest
import pasta
from pasta.augment import import_utils
from pasta.base import ast_utils
from pasta.base import test_utils
from pasta.base import scope
class SplitImportTest(test_utils.TestCase):
def test_split_normal_import(self):
src = 'import aaa, bbb, ccc\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual(ast.Import, type(t.body[1]))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa', 'ccc'])
self.assertEqual([alias.name for alias in t.body[1].names], ['bbb'])
def test_split_from_import(self):
src = 'from aaa import bbb, ccc, ddd\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual(ast.ImportFrom, type(t.body[1]))
self.assertEqual(t.body[0].module, 'aaa')
self.assertEqual(t.body[1].module, 'aaa')
self.assertEqual([alias.name for alias in t.body[0].names], ['bbb', 'ddd'])
def test_split_imports_with_alias(self):
src = 'import aaa as a, bbb as b, ccc as c\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa', 'ccc'])
self.assertEqual([alias.name for alias in t.body[1].names], ['bbb'])
self.assertEqual(t.body[1].names[0].asname, 'b')
def test_split_imports_multiple(self):
src = 'import aaa, bbb, ccc\n'
t = ast.parse(src)
import_node = t.body[0]
alias_bbb = import_node.names[1]
alias_ccc = import_node.names[2]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, alias_bbb)
import_utils.split_import(sc, import_node, alias_ccc)
self.assertEqual(3, len(t.body))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa'])
self.assertEqual([alias.name for alias in t.body[1].names], ['ccc'])
self.assertEqual([alias.name for alias in t.body[2].names], ['bbb'])
def test_split_nested_imports(self):
test_cases = (
'def foo():\n {import_stmt}\n',
'class Foo(object):\n {import_stmt}\n',
'if foo:\n {import_stmt}\nelse:\n pass\n',
'if foo:\n pass\nelse:\n {import_stmt}\n',
'if foo:\n pass\nelif bar:\n {import_stmt}\n',
'try:\n {import_stmt}\nexcept:\n pass\n',
'try:\n pass\nexcept:\n {import_stmt}\n',
'try:\n pass\nfinally:\n {import_stmt}\n',
'for i in foo:\n {import_stmt}\n',
'for i in foo:\n pass\nelse:\n {import_stmt}\n',
'while foo:\n {import_stmt}\n',
)
for template in test_cases:
try:
src = template.format(import_stmt='import aaa, bbb, ccc')
t = ast.parse(src)
sc = scope.analyze(t)
import_node = ast_utils.find_nodes_by_type(t, ast.Import)[0]
import_utils.split_import(sc, import_node, import_node.names[1])
split_import_nodes = ast_utils.find_nodes_by_type(t, ast.Import)
self.assertEqual(1, len(t.body))
self.assertEqual(2, len(split_import_nodes))
self.assertEqual([alias.name for alias in split_import_nodes[0].names],
['aaa', 'ccc'])
self.assertEqual([alias.name for alias in split_import_nodes[1].names],
['bbb'])
except:
self.fail('Failed while executing case:\n%s\nCaused by:\n%s' %
(src, traceback.format_exc()))
class GetUnusedImportsTest(test_utils.TestCase):
def test_normal_imports(self):
src = """\
import a
import b
a.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[1].names[0]])
def test_import_from(self):
src = """\
from my_module import a
import b
from my_module import c
b.foo()
c.bar()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[0]])
def test_import_from_alias(self):
src = """\
from my_module import a, b
b.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[0]])
def test_import_asname(self):
src = """\
from my_module import a as a_mod, b as unused_b_mod
import c as c_mod, d as unused_d_mod
a_mod.foo()
c_mod.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[1],
tree.body[1].names[1]])
def test_dynamic_import(self):
# For now we just don't want to error out on these, longer
# term we want to do the right thing (see
# https://github.com/google/pasta/issues/32)
src = """\
def foo():
import bar
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[])
class RemoveImportTest(test_utils.TestCase):
# Note that we don't test any 'asname' examples but as far as remove_import_alias_node
# is concerned its not a different case because its still just an alias type
# and we don't care about the internals of the alias we're trying to remove.
def test_remove_just_alias(self):
src = "import a, b"
tree = ast.parse(src)
sc = scope.analyze(tree)
unused_b_node = tree.body[0].names[1]
import_utils.remove_import_alias_node(sc, unused_b_node)
self.assertEqual(len(tree.body), 1)
self.assertEqual(type(tree.body[0]), ast.Import)
self.assertEqual(len(tree.body[0].names), 1)
self.assertEqual(tree.body[0].names[0].name, 'a')
def test_remove_just_alias_import_from(self):
src = "from m import a, b"
tree = ast.parse(src)
sc = scope.analyze(tree)
unused_b_node = tree.body[0].names[1]
import_utils.remove_import_alias_node(sc, unused_b_node)
self.assertEqual(len(tree.body), 1)
self.assertEqual(type(tree.body[0]), ast.ImportFrom)
self.assertEqual(len(tree.body[0].names), 1)
self.assertEqual(tree.body[0].names[0].name, 'a')
def test_remove_full_import(self):
src = "import a"
tree = ast.parse(src)
sc = scope.analyze(tree)
a_node = tree.body[0].names[0]
import_utils.remove_import_alias_node(sc, a_node)
self.assertEqual(len(tree.body), 0)
def test_remove_full_importfrom(self):
src = "from m import a"
tree = ast.parse(src)
sc = scope.analyze(tree)
a_node = tree.body[0].names[0]
import_utils.remove_import_alias_node(sc, a_node)
self.assertEqual(len(tree.body), 0)
class AddImportTest(test_utils.TestCase):
def test_add_normal_import(self):
tree = ast.parse('')
self.assertEqual('a.b.c',
import_utils.add_import(tree, 'a.b.c', from_import=False))
self.assertEqual('import a.b.c\n', pasta.dump(tree))
def test_add_normal_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'd',
import_utils.add_import(tree, 'a.b.c', asname='d', from_import=False)
)
self.assertEqual('import a.b.c as d\n', pasta.dump(tree))
def test_add_from_import(self):
tree = ast.parse('')
self.assertEqual('c',
import_utils.add_import(tree, 'a.b.c', from_import=True))
self.assertEqual('from a.b import c\n', pasta.dump(tree))
def test_add_from_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'd',
import_utils.add_import(tree, 'a.b.c', asname='d', from_import=True)
)
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_single_name_from_import(self):
tree = ast.parse('')
self.assertEqual('foo',
import_utils.add_import(tree, 'foo', from_import=True))
self.assertEqual('import foo\n', pasta.dump(tree))
def test_add_single_name_from_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'bar',
import_utils.add_import(tree, 'foo', asname='bar', from_import=True)
)
self.assertEqual('import foo as bar\n', pasta.dump(tree))
def test_add_existing_import(self):
tree = ast.parse('from a.b import c')
self.assertEqual('c', import_utils.add_import(tree, 'a.b.c'))
self.assertEqual('from a.b import c\n', pasta.dump(tree))
def test_add_existing_import_aliased(self):
tree = ast.parse('from a.b import c as d')
self.assertEqual('d', import_utils.add_import(tree, 'a.b.c'))
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_existing_import_aliased_with_asname(self):
tree = ast.parse('from a.b import c as d')
self.assertEqual('d', import_utils.add_import(tree, 'a.b.c', asname='e'))
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_existing_import_normal_import(self):
tree = ast.parse('import a.b.c')
self.assertEqual('a.b',
import_utils.add_import(tree, 'a.b', from_import=False))
self.assertEqual('import a.b.c\n', pasta.dump(tree))
def test_add_existing_import_normal_import_aliased(self):
tree = ast.parse('import a.b.c as d')
self.assertEqual('a.b',
import_utils.add_import(tree, 'a.b', from_import=False))
self.assertEqual('d',
import_utils.add_import(tree, 'a.b.c', from_import=False))
self.assertEqual('import a.b\nimport a.b.c as d\n', pasta.dump(tree))
def test_add_import_with_conflict(self):
tree = ast.parse('def c(): pass\n')
self.assertEqual('c_1',
import_utils.add_import(tree, 'a.b.c', from_import=True))
self.assertEqual(
'from a.b import c as c_1\ndef c():\n pass\n', pasta.dump(tree))
def test_add_import_with_asname_with_conflict(self):
tree = ast.parse('def c(): pass\n')
self.assertEqual('c_1',
import_utils.add_import(tree, 'a.b', asname='c', from_import=True))
self.assertEqual(
'from a import b as c_1\ndef c():\n pass\n', pasta.dump(tree))
def test_merge_from_import(self):
tree = ast.parse('from a.b import c')
# x is explicitly not merged
self.assertEqual('x', import_utils.add_import(tree, 'a.b.x',
merge_from_imports=False))
self.assertEqual('from a.b import x\nfrom a.b import c\n',
pasta.dump(tree))
# y is allowed to be merged and is grouped into the first matching import
self.assertEqual('y', import_utils.add_import(tree, 'a.b.y',
merge_from_imports=True))
self.assertEqual('from a.b import x, y\nfrom a.b import c\n',
pasta.dump(tree))
def test_add_import_after_docstring(self):
tree = ast.parse('\'Docstring.\'')
self.assertEqual('a', import_utils.add_import(tree, 'a'))
self.assertEqual('\'Docstring.\'\nimport a\n', pasta.dump(tree))
class RemoveDuplicatesTest(test_utils.TestCase):
def test_remove_duplicates(self):
src = """
import a
import b
import c
import b
import d
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 4)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[1].names[0].name, 'b')
self.assertEqual(tree.body[2].names[0].name, 'c')
self.assertEqual(tree.body[3].names[0].name, 'd')
def test_remove_duplicates_multiple(self):
src = """
import a, b
import b, c
import d, a, e, f
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 3)
self.assertEqual(len(tree.body[0].names), 2)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[0].names[1].name, 'b')
self.assertEqual(len(tree.body[1].names), 1)
self.assertEqual(tree.body[1].names[0].name, 'c')
self.assertEqual(len(tree.body[2].names), 3)
self.assertEqual(tree.body[2].names[0].name, 'd')
self.assertEqual(tree.body[2].names[1].name, 'e')
self.assertEqual(tree.body[2].names[2].name, 'f')
def test_remove_duplicates_empty_node(self):
src = """
import a, b, c
import b, c
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 1)
self.assertEqual(len(tree.body[0].names), 3)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[0].names[1].name, 'b')
self.assertEqual(tree.body[0].names[2].name, 'c')
def test_remove_duplicates_normal_and_from(self):
src = """
import a.b
from a import b
"""
tree = ast.parse(src)
self.assertFalse(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 2)
def test_remove_duplicates_aliases(self):
src = """
import a
import a as ax
import a as ax2
import a as ax
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 3)
self.assertEqual(tree.body[0].names[0].asname, None)
self.assertEqual(tree.body[1].names[0].asname, 'ax')
self.assertEqual(tree.body[2].names[0].asname, 'ax2')
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(SplitImportTest))
result.addTests(unittest.makeSuite(GetUnusedImportsTest))
result.addTests(unittest.makeSuite(RemoveImportTest))
result.addTests(unittest.makeSuite(AddImportTest))
result.addTests(unittest.makeSuite(RemoveDuplicatesTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,65 @@
# coding=utf-8
"""Inline constants in a python module."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import copy
from pasta.base import ast_utils
from pasta.base import scope
class InlineError(Exception):
pass
def inline_name(t, name):
"""Inline a constant name into a module."""
sc = scope.analyze(t)
name_node = sc.names[name]
# The name must be a Name node (not a FunctionDef, etc.)
if not isinstance(name_node.definition, ast.Name):
raise InlineError('%r is not a constant; it has type %r' % (
name, type(name_node.definition)))
assign_node = sc.parent(name_node.definition)
if not isinstance(assign_node, ast.Assign):
raise InlineError('%r is not declared in an assignment' % name)
value = assign_node.value
if not isinstance(sc.parent(assign_node), ast.Module):
raise InlineError('%r is not a top-level name' % name)
# If the name is written anywhere else in this module, it is not constant
for ref in name_node.reads:
if isinstance(getattr(ref, 'ctx', None), ast.Store):
raise InlineError('%r is not a constant' % name)
# Replace all reads of the name with a copy of its value
for ref in name_node.reads:
ast_utils.replace_child(sc.parent(ref), ref, copy.deepcopy(value))
# Remove the assignment to this name
if len(assign_node.targets) == 1:
ast_utils.remove_child(sc.parent(assign_node), assign_node)
else:
tgt_list = [tgt for tgt in assign_node.targets
if not (isinstance(tgt, ast.Name) and tgt.id == name)]
assign_node.targets = tgt_list
@@ -0,0 +1,97 @@
# coding=utf-8
"""Tests for augment.inline."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import textwrap
import unittest
from pasta.augment import inline
from pasta.base import test_utils
class InlineTest(test_utils.TestCase):
def test_inline_simple(self):
src = 'x = 1\na = x\n'
t = ast.parse(src)
inline.inline_name(t, 'x')
self.checkAstsEqual(t, ast.parse('a = 1\n'))
def test_inline_multiple_targets(self):
src = 'x = y = z = 1\na = x + y\n'
t = ast.parse(src)
inline.inline_name(t, 'y')
self.checkAstsEqual(t, ast.parse('x = z = 1\na = x + 1\n'))
def test_inline_multiple_reads(self):
src = textwrap.dedent('''\
CONSTANT = "foo"
def a(b=CONSTANT):
return b == CONSTANT
''')
expected = textwrap.dedent('''\
def a(b="foo"):
return b == "foo"
''')
t = ast.parse(src)
inline.inline_name(t, 'CONSTANT')
self.checkAstsEqual(t, ast.parse(expected))
def test_inline_non_constant_fails(self):
src = textwrap.dedent('''\
NOT_A_CONSTANT = "foo"
NOT_A_CONSTANT += "bar"
''')
t = ast.parse(src)
with self.assertRaisesRegexp(inline.InlineError,
'\'NOT_A_CONSTANT\' is not a constant'):
inline.inline_name(t, 'NOT_A_CONSTANT')
def test_inline_function_fails(self):
src = 'def func(): pass\nfunc()\n'
t = ast.parse(src)
with self.assertRaisesRegexp(
inline.InlineError,
'\'func\' is not a constant; it has type %r' % ast.FunctionDef):
inline.inline_name(t, 'func')
def test_inline_conditional_fails(self):
src = 'if define:\n x = 1\na = x\n'
t = ast.parse(src)
with self.assertRaisesRegexp(inline.InlineError,
'\'x\' is not a top-level name'):
inline.inline_name(t, 'x')
def test_inline_non_assign_fails(self):
src = 'CONSTANT1, CONSTANT2 = values'
t = ast.parse(src)
with self.assertRaisesRegexp(
inline.InlineError, '\'CONSTANT1\' is not declared in an assignment'):
inline.inline_name(t, 'CONSTANT1')
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(InlineTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,154 @@
# coding=utf-8
"""Rename names in a python module."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import six
from pasta.augment import import_utils
from pasta.base import ast_utils
from pasta.base import scope
def rename_external(t, old_name, new_name):
"""Rename an imported name in a module.
This will rewrite all import statements in `tree` that reference the old
module as well as any names in `tree` which reference the imported name. This
may introduce new import statements, but only if necessary.
For example, to move and rename the module `foo.bar.utils` to `foo.bar_utils`:
> rename_external(tree, 'foo.bar.utils', 'foo.bar_utils')
- import foo.bar.utils
+ import foo.bar_utils
- from foo.bar import utils
+ from foo import bar_utils
- from foo.bar import logic, utils
+ from foo.bar import logic
+ from foo import bar_utils
Arguments:
t: (ast.Module) Module syntax tree to perform the rename in. This will be
updated as a result of this function call with all affected nodes changed
and potentially new Import/ImportFrom nodes added.
old_name: (string) Fully-qualified path of the name to replace.
new_name: (string) Fully-qualified path of the name to update to.
Returns:
True if any changes were made, False otherwise.
"""
sc = scope.analyze(t)
if old_name not in sc.external_references:
return False
has_changed = False
renames = {}
already_changed = []
for ref in sc.external_references[old_name]:
if isinstance(ref.node, ast.alias):
parent = sc.parent(ref.node)
# An alias may be the most specific reference to an imported name, but it
# could if it is a child of an ImportFrom, the ImportFrom node's module
# may also need to be updated.
if isinstance(parent, ast.ImportFrom) and parent not in already_changed:
assert _rename_name_in_importfrom(sc, parent, old_name, new_name)
renames[old_name.rsplit('.', 1)[-1]] = new_name.rsplit('.', 1)[-1]
already_changed.append(parent)
else:
ref.node.name = new_name + ref.node.name[len(old_name):]
if not ref.node.asname:
renames[old_name] = new_name
has_changed = True
elif isinstance(ref.node, ast.ImportFrom):
if ref.node not in already_changed:
assert _rename_name_in_importfrom(sc, ref.node, old_name, new_name)
renames[old_name.rsplit('.', 1)[-1]] = new_name.rsplit('.', 1)[-1]
already_changed.append(ref.node)
has_changed = True
for rename_old, rename_new in six.iteritems(renames):
_rename_reads(sc, t, rename_old, rename_new)
return has_changed
def _rename_name_in_importfrom(sc, node, old_name, new_name):
if old_name == new_name:
return False
module_parts = node.module.split('.')
old_parts = old_name.split('.')
new_parts = new_name.split('.')
# If just the module is changing, rename it
if module_parts[:len(old_parts)] == old_parts:
node.module = '.'.join(new_parts + module_parts[len(old_parts):])
return True
# Find the alias node to be changed
for alias_to_change in node.names:
if alias_to_change.name == old_parts[-1]:
break
else:
return False
alias_to_change.name = new_parts[-1]
# Split the import if the package has changed
if module_parts != new_parts[:-1]:
if len(node.names) > 1:
new_import = import_utils.split_import(sc, node, alias_to_change)
new_import.module = '.'.join(new_parts[:-1])
else:
node.module = '.'.join(new_parts[:-1])
return True
def _rename_reads(sc, t, old_name, new_name):
"""Updates all locations in the module where the given name is read.
Arguments:
sc: (scope.Scope) Scope to work in. This should be the scope of `t`.
t: (ast.AST) The AST to perform updates in.
old_name: (string) Dotted name to update.
new_name: (string) Dotted name to replace it with.
Returns:
True if any changes were made, False otherwise.
"""
name_parts = old_name.split('.')
try:
name = sc.names[name_parts[0]]
for part in name_parts[1:]:
name = name.attrs[part]
except KeyError:
return False
has_changed = False
for ref_node in name.reads:
if isinstance(ref_node, (ast.Name, ast.Attribute)):
ast_utils.replace_child(sc.parent(ref_node), ref_node,
ast.parse(new_name).body[0].value)
has_changed = True
return has_changed
@@ -0,0 +1,119 @@
# coding=utf-8
"""Tests for augment.rename."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import unittest
from pasta.augment import rename
from pasta.base import scope
from pasta.base import test_utils
class RenameTest(test_utils.TestCase):
def test_rename_external_in_import(self):
src = 'import aaa.bbb.ccc\naaa.bbb.ccc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy.ccc\nxxx.yyy.ccc.foo()'))
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy\nxxx.yyy.foo()'))
t = ast.parse(src)
self.assertFalse(rename.rename_external(t, 'bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse(src))
def test_rename_external_in_import_with_asname(self):
src = 'import aaa.bbb.ccc as ddd\nddd.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy.ccc as ddd\nddd.foo()'))
def test_rename_external_in_import_multiple_aliases(self):
src = 'import aaa, aaa.bbb, aaa.bbb.ccc'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import aaa, xxx.yyy, xxx.yyy.ccc'))
def test_rename_external_in_importfrom(self):
src = 'from aaa.bbb.ccc import ddd\nddd.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx.yyy.ccc import ddd\nddd.foo()'))
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx.yyy import ddd\nddd.foo()'))
t = ast.parse(src)
self.assertFalse(rename.rename_external(t, 'bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse(src))
def test_rename_external_in_importfrom_alias(self):
src = 'from aaa.bbb import ccc\nccc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx import yyy\nyyy.foo()'))
def test_rename_external_in_importfrom_alias_with_asname(self):
src = 'from aaa.bbb import ccc as abc\nabc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx import yyy as abc\nabc.foo()'))
def test_rename_reads_name(self):
src = 'aaa.bbb()'
t = ast.parse(src)
sc = scope.analyze(t)
self.assertTrue(rename._rename_reads(sc, t, 'aaa', 'xxx'))
self.checkAstsEqual(t, ast.parse('xxx.bbb()'))
def test_rename_reads_name_as_attribute(self):
src = 'aaa.bbb()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse('xxx.yyy.bbb()'))
def test_rename_reads_attribute(self):
src = 'aaa.bbb.ccc()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa.bbb', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse('xxx.yyy.ccc()'))
def test_rename_reads_noop(self):
src = 'aaa.bbb.ccc()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa.bbb.ccc.ddd', 'xxx.yyy')
rename._rename_reads(sc, t, 'bbb.aaa', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse(src))
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(RenameTest))
return result
if __name__ == '__main__':
unittest.main()
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,477 @@
# coding=utf-8
"""Tests for annotate."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import difflib
import itertools
import os.path
from six import with_metaclass
import sys
import textwrap
import unittest
import pasta
from pasta.base import annotate
from pasta.base import ast_utils
from pasta.base import codegen
from pasta.base import formatting as fmt
from pasta.base import test_utils
TESTDATA_DIR = os.path.realpath(
os.path.join(os.path.dirname(pasta.__file__), '../testdata'))
class PrefixSuffixTest(test_utils.TestCase):
def test_block_suffix(self):
src_tpl = textwrap.dedent('''\
{open_block}
pass #a
#b
#c
#d
#e
a
''')
test_cases = (
# first: attribute of the node with the last block
# second: code snippet to open a block
('body', 'def x():'),
('body', 'class X:'),
('body', 'if x:'),
('orelse', 'if x:\n y\nelse:'),
('body', 'if x:\n y\nelif y:'),
('body', 'while x:'),
('orelse', 'while x:\n y\nelse:'),
('finalbody', 'try:\n x\nfinally:'),
('body', 'try:\n x\nexcept:'),
('orelse', 'try:\n x\nexcept:\n y\nelse:'),
('body', 'with x:'),
('body', 'with x, y:'),
('body', 'with x:\n with y:'),
('body', 'for x in y:'),
)
def is_node_for_suffix(node, children_attr):
# Return True if this node contains the 'pass' statement
val = getattr(node, children_attr, None)
return isinstance(val, list) and type(val[0]) == ast.Pass
for children_attr, open_block in test_cases:
src = src_tpl.format(open_block=open_block)
t = pasta.parse(src)
node_finder = ast_utils.FindNodeVisitor(
lambda node: is_node_for_suffix(node, children_attr))
node_finder.visit(t)
node = node_finder.results[0]
expected = ' #b\n #c\n\n #d\n'
actual = str(fmt.get(node, 'block_suffix_%s' % children_attr))
self.assertMultiLineEqual(
expected, actual,
'Incorrect suffix for code:\n%s\nNode: %s (line %d)\nDiff:\n%s' % (
src, node, node.lineno, '\n'.join(_get_diff(actual, expected))))
self.assertMultiLineEqual(src, pasta.dump(t))
def test_module_suffix(self):
src = 'foo\n#bar\n\n#baz\n'
t = pasta.parse(src)
self.assertEqual(src[src.index('#bar'):], fmt.get(t, 'suffix'))
def test_no_block_suffix_for_single_line_statement(self):
src = 'if x: return y\n #a\n#b\n'
t = pasta.parse(src)
self.assertIsNone(fmt.get(t.body[0], 'block_suffix_body'))
def test_expression_prefix_suffix(self):
src = 'a\n\nfoo\n\n\nb\n'
t = pasta.parse(src)
self.assertEqual('\n', fmt.get(t.body[1], 'prefix'))
self.assertEqual('\n', fmt.get(t.body[1], 'suffix'))
def test_statement_prefix_suffix(self):
src = 'a\n\ndef foo():\n return bar\n\n\nb\n'
t = pasta.parse(src)
self.assertEqual('\n', fmt.get(t.body[1], 'prefix'))
self.assertEqual('', fmt.get(t.body[1], 'suffix'))
class IndentationTest(test_utils.TestCase):
def test_indent_levels(self):
src = textwrap.dedent('''\
foo('begin')
if a:
foo('a1')
if b:
foo('b1')
if c:
foo('c1')
foo('b2')
foo('a2')
foo('end')
''')
t = pasta.parse(src)
call_nodes = ast_utils.find_nodes_by_type(t, (ast.Call,))
call_nodes.sort(key=lambda node: node.lineno)
begin, a1, b1, c1, b2, a2, end = call_nodes
self.assertEqual('', fmt.get(begin, 'indent'))
self.assertEqual(' ', fmt.get(a1, 'indent'))
self.assertEqual(' ', fmt.get(b1, 'indent'))
self.assertEqual(' ', fmt.get(c1, 'indent'))
self.assertEqual(' ', fmt.get(b2, 'indent'))
self.assertEqual(' ', fmt.get(a2, 'indent'))
self.assertEqual('', fmt.get(end, 'indent'))
def test_indent_levels_same_line(self):
src = 'if a: b; c\n'
t = pasta.parse(src)
if_node = t.body[0]
b, c = if_node.body
self.assertIsNone(fmt.get(b, 'indent_diff'))
self.assertIsNone(fmt.get(c, 'indent_diff'))
def test_indent_depths(self):
template = 'if a:\n{first}if b:\n{first}{second}foo()\n'
indents = (' ', ' ' * 2, ' ' * 4, ' ' * 8, '\t', '\t' * 2)
for first, second in itertools.product(indents, indents):
src = template.format(first=first, second=second)
t = pasta.parse(src)
outer_if_node = t.body[0]
inner_if_node = outer_if_node.body[0]
call_node = inner_if_node.body[0]
self.assertEqual('', fmt.get(outer_if_node, 'indent'))
self.assertEqual('', fmt.get(outer_if_node, 'indent_diff'))
self.assertEqual(first, fmt.get(inner_if_node, 'indent'))
self.assertEqual(first, fmt.get(inner_if_node, 'indent_diff'))
self.assertEqual(first + second, fmt.get(call_node, 'indent'))
self.assertEqual(second, fmt.get(call_node, 'indent_diff'))
def test_indent_multiline_string(self):
src = textwrap.dedent('''\
class A:
"""Doc
string."""
pass
''')
t = pasta.parse(src)
docstring, pass_stmt = t.body[0].body
self.assertEqual(' ', fmt.get(docstring, 'indent'))
self.assertEqual(' ', fmt.get(pass_stmt, 'indent'))
def test_indent_multiline_string_with_newline(self):
src = textwrap.dedent('''\
class A:
"""Doc\n
string."""
pass
''')
t = pasta.parse(src)
docstring, pass_stmt = t.body[0].body
self.assertEqual(' ', fmt.get(docstring, 'indent'))
self.assertEqual(' ', fmt.get(pass_stmt, 'indent'))
def test_scope_trailing_comma(self):
template = 'def foo(a, b{trailing_comma}): pass'
for trailing_comma in ('', ',', ' , '):
tree = pasta.parse(template.format(trailing_comma=trailing_comma))
self.assertEqual(trailing_comma.lstrip(' ') + ')',
fmt.get(tree.body[0], 'args_suffix'))
template = 'class Foo(a, b{trailing_comma}): pass'
for trailing_comma in ('', ',', ' , '):
tree = pasta.parse(template.format(trailing_comma=trailing_comma))
self.assertEqual(trailing_comma.lstrip(' ') + ')',
fmt.get(tree.body[0], 'bases_suffix'))
template = 'from mod import (a, b{trailing_comma})'
for trailing_comma in ('', ',', ' , '):
tree = pasta.parse(template.format(trailing_comma=trailing_comma))
self.assertEqual(trailing_comma + ')',
fmt.get(tree.body[0], 'names_suffix'))
def test_indent_extra_newlines(self):
src = textwrap.dedent('''\
if a:
b
''')
t = pasta.parse(src)
if_node = t.body[0]
b = if_node.body[0]
self.assertEqual(' ', fmt.get(b, 'indent_diff'))
def test_indent_extra_newlines_with_comment(self):
src = textwrap.dedent('''\
if a:
#not here
b
''')
t = pasta.parse(src)
if_node = t.body[0]
b = if_node.body[0]
self.assertEqual(' ', fmt.get(b, 'indent_diff'))
def test_autoindent(self):
src = textwrap.dedent('''\
def a():
b
c
''')
expected = textwrap.dedent('''\
def a():
b
new_node
''')
t = pasta.parse(src)
# Repace the second node and make sure the indent level is corrected
t.body[0].body[1] = ast.Expr(ast.Name(id='new_node'))
self.assertMultiLineEqual(expected, codegen.to_str(t))
@test_utils.requires_features('mixed_tabs_spaces')
def test_mixed_tabs_spaces_indentation(self):
pasta.parse(textwrap.dedent('''\
if a:
b
{ONETAB}c
''').format(ONETAB='\t'))
@test_utils.requires_features('mixed_tabs_spaces')
def test_tab_below_spaces(self):
for num_spaces in range(1, 8):
t = pasta.parse(textwrap.dedent('''\
if a:
{WS}if b:
{ONETAB}c
''').format(ONETAB='\t', WS=' ' * num_spaces))
node_c = t.body[0].body[0].body[0]
self.assertEqual(fmt.get(node_c, 'indent_diff'), ' ' * (8 - num_spaces))
@test_utils.requires_features('mixed_tabs_spaces')
def test_tabs_below_spaces_and_tab(self):
for num_spaces in range(1, 8):
t = pasta.parse(textwrap.dedent('''\
if a:
{WS}{ONETAB}if b:
{ONETAB}{ONETAB}c
''').format(ONETAB='\t', WS=' ' * num_spaces))
node_c = t.body[0].body[0].body[0]
self.assertEqual(fmt.get(node_c, 'indent_diff'), '\t')
def _is_syntax_valid(filepath):
with open(filepath, 'r') as f:
try:
ast.parse(f.read())
except SyntaxError:
return False
return True
class SymmetricTestMeta(type):
def __new__(mcs, name, bases, inst_dict):
# Helper function to generate a test method
def symmetric_test_generator(filepath):
def test(self):
with open(filepath, 'r') as handle:
src = handle.read()
t = ast_utils.parse(src)
annotator = annotate.AstAnnotator(src)
annotator.visit(t)
self.assertMultiLineEqual(codegen.to_str(t), src)
self.assertEqual([], annotator.tokens._parens, 'Unmatched parens')
return test
# Add a test method for each input file
test_method_prefix = 'test_symmetric_'
data_dir = os.path.join(TESTDATA_DIR, 'ast')
for dirpath, dirs, files in os.walk(data_dir):
for filename in files:
if filename.endswith('.in'):
full_path = os.path.join(dirpath, filename)
inst_dict[test_method_prefix + filename[:-3]] = unittest.skipIf(
not _is_syntax_valid(full_path),
'Test contains syntax not supported by this version.',
)(symmetric_test_generator(full_path))
return type.__new__(mcs, name, bases, inst_dict)
class SymmetricTest(with_metaclass(SymmetricTestMeta, test_utils.TestCase)):
"""Validates the symmetry property.
After parsing + annotating a module, regenerating the source code for it
should yield the same result.
"""
def _get_node_identifier(node):
for attr in ('id', 'name', 'attr', 'arg', 'module'):
if isinstance(getattr(node, attr, None), str):
return getattr(node, attr, '')
return ''
class PrefixSuffixGoldenTestMeta(type):
def __new__(mcs, name, bases, inst_dict):
# Helper function to generate a test method
def golden_test_generator(input_file, golden_file):
def test(self):
with open(input_file, 'r') as handle:
src = handle.read()
t = ast_utils.parse(src)
annotator = annotate.AstAnnotator(src)
annotator.visit(t)
def escape(s):
return '' if s is None else s.replace('\n', '\\n')
result = '\n'.join(
"{0:12} {1:20} \tprefix=|{2}|\tsuffix=|{3}|\tindent=|{4}|".format(
str((getattr(n, 'lineno', -1), getattr(n, 'col_offset', -1))),
type(n).__name__ + ' ' + _get_node_identifier(n),
escape(fmt.get(n, 'prefix')),
escape(fmt.get(n, 'suffix')),
escape(fmt.get(n, 'indent')))
for n in ast.walk(t)) + '\n'
# If specified, write the golden data instead of checking it
if getattr(self, 'generate_goldens', False):
if not os.path.isdir(os.path.dirname(golden_file)):
os.makedirs(os.path.dirname(golden_file))
with open(golden_file, 'w') as f:
f.write(result)
print('Wrote: ' + golden_file)
return
try:
with open(golden_file, 'r') as f:
golden = f.read()
except IOError:
self.fail('Missing golden data.')
self.assertMultiLineEqual(golden, result)
return test
# Add a test method for each input file
test_method_prefix = 'test_golden_prefix_suffix_'
data_dir = os.path.join(TESTDATA_DIR, 'ast')
python_version = '%d.%d' % sys.version_info[:2]
for dirpath, dirs, files in os.walk(data_dir):
for filename in files:
if filename.endswith('.in'):
full_path = os.path.join(dirpath, filename)
golden_path = os.path.join(dirpath, 'golden', python_version,
filename[:-3] + '.out')
inst_dict[test_method_prefix + filename[:-3]] = unittest.skipIf(
not _is_syntax_valid(full_path),
'Test contains syntax not supported by this version.',
)(golden_test_generator(full_path, golden_path))
return type.__new__(mcs, name, bases, inst_dict)
class PrefixSuffixGoldenTest(with_metaclass(PrefixSuffixGoldenTestMeta,
test_utils.TestCase)):
"""Checks the prefix and suffix on each node in the AST.
This uses golden files in testdata/ast/golden. To regenerate these files, run
python setup.py test -s pasta.base.annotate_test.generate_goldens
"""
maxDiff = None
class ManualEditsTest(test_utils.TestCase):
"""Tests that we can handle ASTs that have been modified.
Such ASTs may lack position information (lineno/col_offset) on some nodes.
"""
def test_call_no_pos(self):
"""Tests that Call node traversal works without position information."""
src = 'f(a)'
t = pasta.parse(src)
node = ast_utils.find_nodes_by_type(t, (ast.Call,))[0]
node.keywords.append(ast.keyword(arg='b', value=ast.Num(n=0)))
self.assertEqual('f(a, b=0)', pasta.dump(t))
def test_call_illegal_pos(self):
"""Tests that Call node traversal works even with illegal positions."""
src = 'f(a)'
t = pasta.parse(src)
node = ast_utils.find_nodes_by_type(t, (ast.Call,))[0]
node.keywords.append(ast.keyword(arg='b', value=ast.Num(n=0)))
# This position would put b=0 before a, so it should be ignored.
node.keywords[-1].value.lineno = 0
node.keywords[-1].value.col_offset = 0
self.assertEqual('f(a, b=0)', pasta.dump(t))
class FstringTest(test_utils.TestCase):
"""Tests fstring support more in-depth."""
@test_utils.requires_features('fstring')
def test_fstring(self):
src = 'f"a {b} c d {e}"'
t = pasta.parse(src)
node = t.body[0].value
self.assertEqual(
fmt.get(node, 'content'),
'f"a {__pasta_fstring_val_0__} c d {__pasta_fstring_val_1__}"')
@test_utils.requires_features('fstring')
def test_fstring_escaping(self):
src = 'f"a {{{b} {{c}}"'
t = pasta.parse(src)
node = t.body[0].value
self.assertEqual(
fmt.get(node, 'content'),
'f"a {{{__pasta_fstring_val_0__} {{c}}"')
def _get_diff(before, after):
return difflib.ndiff(after.splitlines(), before.splitlines())
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(ManualEditsTest))
result.addTests(unittest.makeSuite(SymmetricTest))
result.addTests(unittest.makeSuite(PrefixSuffixTest))
result.addTests(unittest.makeSuite(PrefixSuffixGoldenTest))
result.addTests(unittest.makeSuite(FstringTest))
return result
def generate_goldens():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(PrefixSuffixGoldenTest))
setattr(PrefixSuffixGoldenTest, 'generate_goldens', True)
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,38 @@
"""Constants relevant to ast code."""
import ast
NODE_TYPE_TO_TOKENS = {
ast.Add: ('+',),
ast.And: ('and',),
ast.BitAnd: ('&',),
ast.BitOr: ('|',),
ast.BitXor: ('^',),
ast.Div: ('/',),
ast.Eq: ('==',),
ast.FloorDiv: ('//',),
ast.Gt: ('>',),
ast.GtE: ('>=',),
ast.In: ('in',),
ast.Invert: ('~',),
ast.Is: ('is',),
ast.IsNot: ('is', 'not',),
ast.LShift: ('<<',),
ast.Lt: ('<',),
ast.LtE: ('<=',),
ast.Mod: ('%',),
ast.Mult: ('*',),
ast.Not: ('not',),
ast.NotEq: ('!=',),
ast.NotIn: ('not', 'in',),
ast.Or: ('or',),
ast.Pow: ('**',),
ast.RShift: ('>>',),
ast.Sub: ('-',),
ast.UAdd: ('+',),
ast.USub: ('-',),
}
if hasattr(ast, 'MatMult'):
NODE_TYPE_TO_TOKENS[ast.MatMult] = ('@',)
@@ -0,0 +1,179 @@
# coding=utf-8
"""Helpers for working with python ASTs."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import re
from pasta.augment import errors
from pasta.base import formatting as fmt
# From PEP-0263 -- https://www.python.org/dev/peps/pep-0263/
_CODING_PATTERN = re.compile('^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)')
_AST_OP_NODES = (
ast.And, ast.Or, ast.Eq, ast.NotEq, ast.Is, ast.IsNot, ast.In, ast.NotIn,
ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Add, ast.Sub, ast.Mult, ast.Div,
ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitAnd, ast.BitOr, ast.BitXor,
ast.FloorDiv, ast.Invert, ast.Not, ast.UAdd, ast.USub
)
class _TreeNormalizer(ast.NodeTransformer):
"""Replaces all op nodes with unique instances."""
def visit(self, node):
if isinstance(node, _AST_OP_NODES):
return node.__class__()
return super(_TreeNormalizer, self).visit(node)
_tree_normalizer = _TreeNormalizer()
def parse(src):
"""Replaces ast.parse; ensures additional properties on the parsed tree.
This enforces the assumption that each node in the ast is unique.
"""
tree = ast.parse(sanitize_source(src))
_tree_normalizer.visit(tree)
return tree
def sanitize_source(src):
"""Strip the 'coding' directive from python source code, if present.
This is a workaround for https://bugs.python.org/issue18960. Also see PEP-0263.
"""
src_lines = src.splitlines(True)
for i, line in enumerate(src_lines[:2]):
if _CODING_PATTERN.match(line):
src_lines[i] = re.sub('#.*$', '# (removed coding)', line)
return ''.join(src_lines)
def find_nodes_by_type(node, accept_types):
visitor = FindNodeVisitor(lambda n: isinstance(n, accept_types))
visitor.visit(node)
return visitor.results
class FindNodeVisitor(ast.NodeVisitor):
def __init__(self, condition):
self._condition = condition
self.results = []
def visit(self, node):
if self._condition(node):
self.results.append(node)
super(FindNodeVisitor, self).visit(node)
def get_last_child(node):
"""Get the last child node of a block statement.
The input must be a block statement (e.g. ast.For, ast.With, etc).
Examples:
1. with first():
second()
last()
2. try:
first()
except:
second()
finally:
last()
In both cases, the last child is the node for `last`.
"""
if isinstance(node, ast.Module):
try:
return node.body[-1]
except IndexError:
return None
if isinstance(node, ast.If):
if (len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If) and
fmt.get(node.orelse[0], 'is_elif')):
return get_last_child(node.orelse[0])
if node.orelse:
return node.orelse[-1]
elif isinstance(node, ast.With):
if (len(node.body) == 1 and isinstance(node.body[0], ast.With) and
fmt.get(node.body[0], 'is_continued')):
return get_last_child(node.body[0])
elif hasattr(ast, 'Try') and isinstance(node, ast.Try):
if node.finalbody:
return node.finalbody[-1]
if node.orelse:
return node.orelse[-1]
elif hasattr(ast, 'TryFinally') and isinstance(node, ast.TryFinally):
if node.finalbody:
return node.finalbody[-1]
elif hasattr(ast, 'TryExcept') and isinstance(node, ast.TryExcept):
if node.orelse:
return node.orelse[-1]
if node.handlers:
return get_last_child(node.handlers[-1])
return node.body[-1]
def remove_child(parent, child):
for _, field_value in ast.iter_fields(parent):
if isinstance(field_value, list) and child in field_value:
field_value.remove(child)
return
raise errors.InvalidAstError('Unable to find list containing child %r on '
'parent node %r' % (child, parent))
def replace_child(parent, node, replace_with):
"""Replace a node's child with another node while preserving formatting.
Arguments:
parent: (ast.AST) Parent node to replace a child of.
node: (ast.AST) Child node to replace.
replace_with: (ast.AST) New child node.
"""
# TODO(soupytwist): Don't refer to the formatting dict directly
if hasattr(node, fmt.PASTA_DICT):
fmt.set(replace_with, 'prefix', fmt.get(node, 'prefix'))
fmt.set(replace_with, 'suffix', fmt.get(node, 'suffix'))
for field in parent._fields:
field_val = getattr(parent, field, None)
if field_val == node:
setattr(parent, field, replace_with)
return
elif isinstance(field_val, list):
try:
field_val[field_val.index(node)] = replace_with
return
except ValueError:
pass
raise errors.InvalidAstError('Node %r is not a child of %r' % (node, parent))
def has_docstring(node):
return (hasattr(node, 'body') and node.body and
isinstance(node.body[0], ast.Expr) and
isinstance(node.body[0].value, ast.Str))
@@ -0,0 +1,123 @@
# coding=utf-8
"""Tests for ast_utils."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pasta
from pasta.augment import errors
from pasta.base import ast_utils
from pasta.base import test_utils
class UtilsTest(test_utils.TestCase):
def test_sanitize_source(self):
coding_lines = (
'# -*- coding: latin-1 -*-',
'# -*- coding: iso-8859-15 -*-',
'# vim: set fileencoding=ascii :',
'# This Python file uses the following encoding: utf-8',
)
src_template = '{coding}\na = 123\n'
sanitized_src = '# (removed coding)\na = 123\n'
for line in coding_lines:
src = src_template.format(coding=line)
# Replaced on lines 1 and 2
self.assertEqual(sanitized_src, ast_utils.sanitize_source(src))
src_prefix = '"""Docstring."""\n'
self.assertEqual(src_prefix + sanitized_src,
ast_utils.sanitize_source(src_prefix + src))
# Unchanged on line 3
src_prefix = '"""Docstring."""\n# line 2\n'
self.assertEqual(src_prefix + src,
ast_utils.sanitize_source(src_prefix + src))
class AlterChildTest(test_utils.TestCase):
def testRemoveChildMethod(self):
src = """\
class C():
def f(x):
return x + 2
def g(x):
return x + 3
"""
tree = pasta.parse(src)
class_node = tree.body[0]
meth1_node = class_node.body[0]
ast_utils.remove_child(class_node, meth1_node)
result = pasta.dump(tree)
expected = """\
class C():
def g(x):
return x + 3
"""
self.assertEqual(result, expected)
def testRemoveAlias(self):
src = "from a import b, c"
tree = pasta.parse(src)
import_node = tree.body[0]
alias1 = import_node.names[0]
ast_utils.remove_child(import_node, alias1)
self.assertEqual(pasta.dump(tree), "from a import c")
def testRemoveFromBlock(self):
src = """\
if a:
print("foo!")
x = 1
"""
tree = pasta.parse(src)
if_block = tree.body[0]
print_stmt = if_block.body[0]
ast_utils.remove_child(if_block, print_stmt)
expected = """\
if a:
x = 1
"""
self.assertEqual(pasta.dump(tree), expected)
def testReplaceChildInBody(self):
src = 'def foo():\n a = 0\n a += 1 # replace this\n return a\n'
replace_with = pasta.parse('foo(a + 1) # trailing comment\n').body[0]
expected = 'def foo():\n a = 0\n foo(a + 1) # replace this\n return a\n'
t = pasta.parse(src)
parent = t.body[0]
node_to_replace = parent.body[1]
ast_utils.replace_child(parent, node_to_replace, replace_with)
self.assertEqual(expected, pasta.dump(t))
def testReplaceChildInvalid(self):
src = 'def foo():\n return 1\nx = 1\n'
replace_with = pasta.parse('bar()').body[0]
t = pasta.parse(src)
parent = t.body[0]
node_to_replace = t.body[1]
with self.assertRaises(errors.InvalidAstError):
ast_utils.replace_child(parent, node_to_replace, replace_with)
@@ -0,0 +1,160 @@
# coding=utf-8
"""Generate code from an annotated syntax tree."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import six
from pasta.base import annotate
from pasta.base import formatting as fmt
from pasta.base import fstring_utils
class PrintError(Exception):
"""An exception for when we failed to print the tree."""
class Printer(annotate.BaseVisitor):
"""Traverses an AST and generates formatted python source code.
This uses the same base visitor as annotating the AST, but instead of eating a
token it spits one out. For special formatting information which was stored on
the node, this is output exactly as it was read in unless one or more of the
dependency attributes used to generate it has changed, in which case its
default formatting is used.
"""
def __init__(self):
super(Printer, self).__init__()
self.code = ''
def visit(self, node):
node._printer_info = collections.defaultdict(lambda: False)
try:
super(Printer, self).visit(node)
except (TypeError, ValueError, IndexError, KeyError) as e:
raise PrintError(e)
del node._printer_info
def visit_Num(self, node):
self.prefix(node)
content = fmt.get(node, 'content')
self.code += content if content is not None else repr(node.n)
self.suffix(node)
def visit_Str(self, node):
self.prefix(node)
content = fmt.get(node, 'content')
self.code += content if content is not None else repr(node.s)
self.suffix(node)
def visit_JoinedStr(self, node):
self.prefix(node)
content = fmt.get(node, 'content')
if content is None:
parts = []
for val in node.values:
if isinstance(val, ast.Str):
parts.append(val.s)
else:
parts.append(fstring_utils.placeholder(len(parts)))
content = repr(''.join(parts))
values = [to_str(v) for v in fstring_utils.get_formatted_values(node)]
self.code += fstring_utils.perform_replacements(content, values)
self.suffix(node)
def visit_Bytes(self, node):
self.prefix(node)
content = fmt.get(node, 'content')
self.code += content if content is not None else repr(node.s)
self.suffix(node)
def token(self, value):
self.code += value
def optional_token(self, node, attr_name, token_val,
allow_whitespace_prefix=False, default=False):
del allow_whitespace_prefix
value = fmt.get(node, attr_name)
if value is None and default:
value = token_val
self.code += value or ''
def attr(self, node, attr_name, attr_vals, deps=None, default=None):
"""Add the formatted data stored for a given attribute on this node.
If any of the dependent attributes of the node have changed since it was
annotated, then the stored formatted data for this attr_name is no longer
valid, and we must use the default instead.
Arguments:
node: (ast.AST) An AST node to retrieve formatting information from.
attr_name: (string) Name to load the formatting information from.
attr_vals: (list of functions/strings) Unused here.
deps: (optional, set of strings) Attributes of the node which the stored
formatting data depends on.
default: (string) Default formatted data for this attribute.
"""
del attr_vals
if not hasattr(node, '_printer_info') or node._printer_info[attr_name]:
return
node._printer_info[attr_name] = True
val = fmt.get(node, attr_name)
if (val is None or deps and
any(getattr(node, dep, None) != fmt.get(node, dep + '__src')
for dep in deps)):
val = default
self.code += val if val is not None else ''
def check_is_elif(self, node):
try:
return fmt.get(node, 'is_elif')
except AttributeError:
return False
def check_is_continued_try(self, node):
# TODO: Don't set extra attributes on nodes
return getattr(node, 'is_continued', False)
def check_is_continued_with(self, node):
# TODO: Don't set extra attributes on nodes
return getattr(node, 'is_continued', False)
def to_str(tree):
"""Convenient function to get the python source for an AST."""
p = Printer()
# Detect the most prevalent indentation style in the file and use it when
# printing indented nodes which don't have formatting data.
seen_indent_diffs = collections.defaultdict(lambda: 0)
for node in ast.walk(tree):
indent_diff = fmt.get(node, 'indent_diff', '')
if indent_diff:
seen_indent_diffs[indent_diff] += 1
if seen_indent_diffs:
indent_diff, _ = max(six.iteritems(seen_indent_diffs),
key=lambda tup: tup[1] if tup[0] else -1)
p.set_default_indent_diff(indent_diff)
p.visit(tree)
return p.code
@@ -0,0 +1,106 @@
# coding=utf-8
"""Tests for generating code from a non-annotated ast."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import os.path
import unittest
from six import with_metaclass
import pasta
from pasta.base import codegen
from pasta.base import test_utils
TESTDATA_DIR = os.path.realpath(
os.path.join(os.path.dirname(pasta.__file__), '../testdata'))
def _is_syntax_valid(filepath):
with open(filepath, 'r') as f:
try:
ast.parse(f.read())
except SyntaxError:
return False
return True
class AutoFormatTestMeta(type):
def __new__(mcs, name, bases, inst_dict):
# Helper function to generate a test method
def auto_format_test_generator(input_file):
def test(self):
with open(input_file, 'r') as handle:
src = handle.read()
t = ast.parse(src)
auto_formatted = codegen.to_str(t)
self.assertMultiLineEqual(src, auto_formatted)
return test
# Add a test method for each input file
test_method_prefix = 'test_auto_format_'
data_dir = os.path.join(TESTDATA_DIR, 'codegen')
for dirpath, _, files in os.walk(data_dir):
for filename in files:
if filename.endswith('.in'):
full_path = os.path.join(dirpath, filename)
inst_dict[test_method_prefix + filename[:-3]] = unittest.skipIf(
not _is_syntax_valid(full_path),
'Test contains syntax not supported by this version.',
)(auto_format_test_generator(full_path))
return type.__new__(mcs, name, bases, inst_dict)
class AutoFormatTest(with_metaclass(AutoFormatTestMeta, test_utils.TestCase)):
"""Tests that code without formatting info is printed neatly."""
def test_imports(self):
src = 'from a import b\nimport c, d\nfrom ..e import f, g\n'
t = ast.parse(src)
self.assertEqual(src, pasta.dump(t))
@test_utils.requires_features('exec_node')
def test_exec_node_default(self):
src = 'exec foo in bar'
t = ast.parse(src)
self.assertEqual('exec(foo, bar)\n', pasta.dump(t))
@test_utils.requires_features('bytes_node')
def test_bytes(self):
src = "b'foo'"
t = ast.parse(src)
self.assertEqual("b'foo'\n", pasta.dump(t))
def test_default_indentation(self):
for indent in (' ', ' ', '\t'):
src ='def a():\n' + indent + 'b\n'
t = pasta.parse(src)
t.body.extend(ast.parse('def c(): d').body)
self.assertEqual(codegen.to_str(t),
src + 'def c():\n' + indent + 'd\n')
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(AutoFormatTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,49 @@
# coding=utf-8
"""Operations for storing and retrieving formatting info on ast nodes."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
PASTA_DICT = '__pasta__'
def get(node, name, default=None):
try:
return _formatting_dict(node).get(name, default)
except AttributeError:
return default
def set(node, name, value):
if not hasattr(node, PASTA_DICT):
try:
setattr(node, PASTA_DICT, {})
except AttributeError:
pass
_formatting_dict(node)[name] = value
def append(node, name, value):
set(node, name, get(node, name, '') + value)
def prepend(node, name, value):
set(node, name, value + get(node, name, ''))
def _formatting_dict(node):
return getattr(node, PASTA_DICT)
@@ -0,0 +1,44 @@
# coding=utf-8
"""Helpers for working with fstrings (python3.6+)."""
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
_FSTRING_VAL_PLACEHOLDER = '__pasta_fstring_val_{index}__'
def get_formatted_values(joined_str):
"""Get all FormattedValues from a JoinedStr, in order."""
return [v for v in joined_str.values if isinstance(v, ast.FormattedValue)]
def placeholder(val_index):
"""Get the placeholder token for a FormattedValue in an fstring."""
return _FSTRING_VAL_PLACEHOLDER.format(index=val_index)
def perform_replacements(fstr, values):
"""Replace placeholders in an fstring with subexpressions."""
for i, value in enumerate(values):
fstr = fstr.replace(_wrap(placeholder(i)), _wrap(value))
return fstr
def _wrap(s):
return '{%s}' % s
@@ -0,0 +1,277 @@
# coding=utf-8
"""Perform static analysis on python syntax trees."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import six
# TODO: Support relative imports
# Represents a reference to something external to the module.
# Fields:
# name: (string) The full dotted name being referenced.
# node: (ast.AST) The AST node where the reference is defined.
# name_ref: (Name) The name object that refers to the imported name, if
# applicable. This may not be the same id if the import is aliased.
ExternalReference = collections.namedtuple('ExternalReference',
('name', 'node', 'name_ref'))
class ScopeVisitor(ast.NodeVisitor):
def __init__(self):
super(ScopeVisitor, self).__init__()
self._parent = None
self.root_scope = self.scope = RootScope(None)
def visit(self, node):
if node is None:
return
if self.root_scope.node is None:
self.root_scope.node = node
self.root_scope.set_parent(node, self._parent)
tmp = self._parent
self._parent = node
super(ScopeVisitor, self).visit(node)
self._parent = tmp
def visit_in_order(self, node, *attrs):
for attr in attrs:
val = getattr(node, attr, None)
if val is None:
continue
if isinstance(val, list):
for item in val:
self.visit(item)
elif isinstance(val, ast.AST):
self.visit(val)
def visit_Import(self, node):
for alias in node.names:
name_parts = alias.name.split('.')
if not alias.asname:
# If not aliased, define the top-level module of the import
cur_name = self.scope.define_name(name_parts[0], alias)
self.root_scope.add_external_reference(name_parts[0], alias,
name_ref=cur_name)
# Define names of sub-modules imported
partial_name = name_parts[0]
for part in name_parts[1:]:
partial_name += '.' + part
cur_name = cur_name.lookup_name(part)
cur_name.define(alias)
self.root_scope.add_external_reference(partial_name, alias,
name_ref=cur_name)
else:
# If the imported name is aliased, define that name only
name = self.scope.define_name(alias.asname, alias)
# Define names of sub-modules imported
for i in range(1, len(name_parts)):
self.root_scope.add_external_reference('.'.join(name_parts[:i]),
alias)
self.root_scope.add_external_reference(alias.name, alias, name_ref=name)
self.generic_visit(node)
def visit_ImportFrom(self, node):
if node.module:
name_parts = node.module.split('.')
for i in range(1, len(name_parts) + 1):
self.root_scope.add_external_reference('.'.join(name_parts[:i]), node)
for alias in node.names:
name = self.scope.define_name(alias.asname or alias.name, alias)
if node.module:
self.root_scope.add_external_reference(
'.'.join((node.module, alias.name)), alias, name_ref=name)
# TODO: else? relative imports
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, (ast.Store, ast.Param)):
self.scope.define_name(node.id, node)
elif isinstance(node.ctx, ast.Load):
self.scope.lookup_name(node.id).add_reference(node)
self.root_scope.set_name_for_node(node, self.scope.lookup_name(node.id))
self.generic_visit(node)
def visit_FunctionDef(self, node):
# Visit decorator list first to avoid declarations in args
self.visit_in_order(node, 'decorator_list')
if isinstance(self.root_scope.parent(node), ast.ClassDef):
pass # TODO: Support referencing methods by "self" where possible
else:
self.scope.define_name(node.name, node)
try:
self.scope = self.scope.create_scope(node)
self.visit_in_order(node, 'args', 'returns', 'body')
finally:
self.scope = self.scope.parent_scope
def visit_arguments(self, node):
self.visit_in_order(node, 'defaults', 'args')
if six.PY2:
# In python 2.x, these names are not Name nodes. Define them explicitly
# to be able to find references in the function body.
for arg_attr_name in ('vararg', 'kwarg'):
arg_name = getattr(node, arg_attr_name, None)
if arg_name is not None:
self.scope.define_name(arg_name, node)
else:
# Visit defaults first to avoid declarations in args
self.visit_in_order(node, 'vararg', 'kwarg')
def visit_arg(self, node):
self.scope.define_name(node.arg, node)
self.generic_visit(node)
def visit_ClassDef(self, node):
self.visit_in_order(node, 'decorator_list', 'bases')
self.scope.define_name(node.name, node)
try:
self.scope = self.scope.create_scope(node)
self.visit_in_order(node, 'body')
finally:
self.scope = self.scope.parent_scope
def visit_Attribute(self, node):
self.generic_visit(node)
node_value_name = self.root_scope.get_name_for_node(node.value)
if node_value_name:
node_name = node_value_name.lookup_name(node.attr)
self.root_scope.set_name_for_node(node, node_name)
node_name.add_reference(node)
class Scope(object):
def __init__(self, parent_scope, node):
self.parent_scope = parent_scope
self.names = {}
self.node = node
def define_name(self, name, node):
try:
name_obj = self.names[name]
except KeyError:
name_obj = self.names[name] = Name(name)
name_obj.define(node)
return name_obj
def lookup_name(self, name):
try:
return self.names[name]
except KeyError:
pass
if self.parent_scope is None:
name_obj = self.names[name] = Name(name)
return name_obj
return self.parent_scope.lookup_name(name)
def get_root_scope(self):
return self.parent_scope.get_root_scope()
def lookup_scope(self, node):
return self.get_root_scope().lookup_scope(node)
def create_scope(self, node):
subscope = Scope(self, node)
self.get_root_scope()._set_scope_for_node(node, subscope)
return subscope
class RootScope(Scope):
def __init__(self, node):
super(RootScope, self).__init__(None, node)
self.external_references = {}
self._parents = {}
self._nodes_to_names = {}
self._node_scopes = {}
def add_external_reference(self, name, node, name_ref=None):
ref = ExternalReference(name=name, node=node, name_ref=name_ref)
if name in self.external_references:
self.external_references[name].append(ref)
else:
self.external_references[name] = [ref]
def get_root_scope(self):
return self
def parent(self, node):
return self._parents.get(node, None)
def set_parent(self, node, parent):
self._parents[node] = parent
if parent is None:
self._node_scopes[node] = self
def get_name_for_node(self, node):
return self._nodes_to_names.get(node, None)
def set_name_for_node(self, node, name):
self._nodes_to_names[node] = name
def lookup_scope(self, node):
while node:
try:
return self._node_scopes[node]
except KeyError:
node = self.parent(node)
return None
def _set_scope_for_node(self, node, node_scope):
self._node_scopes[node] = node_scope
# Should probably also have a scope?
class Name(object):
def __init__(self, id):
self.id = id
self.definition = None
self.reads = []
self.attrs = {}
def add_reference(self, node):
self.reads.append(node)
def define(self, node):
if self.definition:
self.reads.append(node)
else:
self.definition = node
def lookup_name(self, name):
try:
return self.attrs[name]
except KeyError:
name_obj = self.attrs[name] = Name('.'.join((self.id, name)))
return name_obj
def analyze(tree):
v = ScopeVisitor()
v.visit(tree)
return v.scope
@@ -0,0 +1,467 @@
# coding=utf-8
"""Tests for scope."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import textwrap
import unittest
from pasta.base import ast_utils
from pasta.base import scope
from pasta.base import test_utils
class ScopeTest(test_utils.TestCase):
def test_top_level_imports(self):
self.maxDiff = None
source = textwrap.dedent("""\
import aaa
import bbb, ccc.ddd
import aaa.bbb.ccc
from eee import fff
from ggg.hhh import iii, jjj
""")
tree = ast.parse(source)
nodes = tree.body
node_1_aaa = nodes[0].names[0]
node_2_bbb = nodes[1].names[0]
node_2_ccc_ddd = nodes[1].names[1]
node_3_aaa_bbb_ccc = nodes[2].names[0]
node_4_eee = nodes[3]
node_4_fff = nodes[3].names[0]
node_5_ggg_hhh = nodes[4]
node_5_iii = nodes[4].names[0]
node_5_jjj = nodes[4].names[1]
s = scope.analyze(tree)
self.assertItemsEqual(
s.names.keys(), {
'aaa', 'bbb', 'ccc', 'fff', 'iii', 'jjj'
})
self.assertItemsEqual(
s.external_references.keys(), {
'aaa', 'bbb', 'ccc', 'ccc.ddd', 'aaa.bbb', 'aaa.bbb.ccc', 'eee',
'eee.fff', 'ggg', 'ggg.hhh', 'ggg.hhh.iii', 'ggg.hhh.jjj'
})
self.assertItemsEqual(s.external_references['aaa'], [
scope.ExternalReference('aaa', node_1_aaa, s.names['aaa']),
scope.ExternalReference('aaa', node_3_aaa_bbb_ccc, s.names['aaa']),
])
self.assertItemsEqual(s.external_references['bbb'], [
scope.ExternalReference('bbb', node_2_bbb, s.names['bbb']),
])
self.assertItemsEqual(s.external_references['ccc'], [
scope.ExternalReference('ccc', node_2_ccc_ddd, s.names['ccc']),
])
self.assertItemsEqual(s.external_references['ccc.ddd'], [
scope.ExternalReference('ccc.ddd', node_2_ccc_ddd,
s.names['ccc'].attrs['ddd']),
])
self.assertItemsEqual(s.external_references['aaa.bbb'], [
scope.ExternalReference('aaa.bbb', node_3_aaa_bbb_ccc,
s.names['aaa'].attrs['bbb']),
])
self.assertItemsEqual(s.external_references['aaa.bbb.ccc'], [
scope.ExternalReference('aaa.bbb.ccc', node_3_aaa_bbb_ccc,
s.names['aaa'].attrs['bbb'].attrs['ccc']),
])
self.assertItemsEqual(s.external_references['eee'], [
scope.ExternalReference('eee', node_4_eee, None),
])
self.assertItemsEqual(s.external_references['eee.fff'], [
scope.ExternalReference('eee.fff', node_4_fff, s.names['fff']),
])
self.assertItemsEqual(s.external_references['ggg'], [
scope.ExternalReference('ggg', node_5_ggg_hhh, None),
])
self.assertItemsEqual(s.external_references['ggg.hhh'], [
scope.ExternalReference('ggg.hhh', node_5_ggg_hhh, None),
])
self.assertItemsEqual(s.external_references['ggg.hhh.iii'], [
scope.ExternalReference('ggg.hhh.iii', node_5_iii, s.names['iii']),
])
self.assertItemsEqual(s.external_references['ggg.hhh.jjj'], [
scope.ExternalReference('ggg.hhh.jjj', node_5_jjj, s.names['jjj']),
])
self.assertIs(s.names['aaa'].definition, node_1_aaa)
self.assertIs(s.names['bbb'].definition, node_2_bbb)
self.assertIs(s.names['ccc'].definition, node_2_ccc_ddd)
self.assertIs(s.names['fff'].definition, node_4_fff)
self.assertIs(s.names['iii'].definition, node_5_iii)
self.assertIs(s.names['jjj'].definition, node_5_jjj)
self.assertItemsEqual(s.names['aaa'].reads, [node_3_aaa_bbb_ccc])
for ref in {'bbb', 'ccc', 'fff', 'iii', 'jjj'}:
self.assertEqual(s.names[ref].reads, [], 'Expected no reads for %s' % ref)
def test_if_nested_imports(self):
source = textwrap.dedent("""\
if a:
import aaa
elif b:
import bbb
else:
import ccc
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa, node_bbb, node_ccc = ast_utils.find_nodes_by_type(tree, ast.alias)
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'bbb', 'ccc', 'a', 'b'})
self.assertItemsEqual(s.external_references.keys(), {'aaa', 'bbb', 'ccc'})
self.assertEqual(s.names['aaa'].definition, node_aaa)
self.assertEqual(s.names['bbb'].definition, node_bbb)
self.assertEqual(s.names['ccc'].definition, node_ccc)
self.assertIsNone(s.names['a'].definition)
self.assertIsNone(s.names['b'].definition)
for ref in {'aaa', 'bbb', 'ccc'}:
self.assertEqual(s.names[ref].reads, [],
'Expected no reads for %s' % ref)
def test_try_nested_imports(self):
source = textwrap.dedent("""\
try:
import aaa
except:
import bbb
finally:
import ccc
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa, node_bbb, node_ccc = ast_utils.find_nodes_by_type(tree, ast.alias)
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'bbb', 'ccc'})
self.assertItemsEqual(s.external_references.keys(), {'aaa', 'bbb', 'ccc'})
self.assertEqual(s.names['aaa'].definition, node_aaa)
self.assertEqual(s.names['bbb'].definition, node_bbb)
self.assertEqual(s.names['ccc'].definition, node_ccc)
for ref in {'aaa', 'bbb', 'ccc'}:
self.assertEqual(s.names[ref].reads, [],
'Expected no reads for %s' % ref)
def test_functiondef_nested_imports(self):
source = textwrap.dedent("""\
def foo(bar):
import aaa
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa = ast_utils.find_nodes_by_type(tree, ast.alias)[0]
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
def test_classdef_nested_imports(self):
source = textwrap.dedent("""\
class Foo():
import aaa
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa = nodes[0].body[0].names[0]
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'Foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
def test_multilevel_import_reads(self):
source = textwrap.dedent("""\
import aaa.bbb.ccc
aaa.bbb.ccc.foo()
""")
tree = ast.parse(source)
nodes = tree.body
node_ref = nodes[1].value.func.value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa'})
self.assertItemsEqual(s.external_references.keys(),
{'aaa', 'aaa.bbb', 'aaa.bbb.ccc'})
self.assertItemsEqual(s.names['aaa'].reads, [node_ref.value.value])
self.assertItemsEqual(s.names['aaa'].attrs['bbb'].reads, [node_ref.value])
self.assertItemsEqual(s.names['aaa'].attrs['bbb'].attrs['ccc'].reads,
[node_ref])
def test_import_reads_in_functiondef(self):
source = textwrap.dedent("""\
import aaa
@aaa.x
def foo(bar):
return aaa
""")
tree = ast.parse(source)
nodes = tree.body
return_value = nodes[1].body[0].value
decorator = nodes[1].decorator_list[0].value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [decorator, return_value])
def test_import_reads_in_classdef(self):
source = textwrap.dedent("""\
import aaa
@aaa.x
class Foo(aaa.Bar):
pass
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa = nodes[0].names[0]
decorator = nodes[1].decorator_list[0].value
base = nodes[1].bases[0].value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'Foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [decorator, base])
def test_import_masked_by_function_arg(self):
source = textwrap.dedent("""\
import aaa
def foo(aaa=aaa):
return aaa
""")
tree = ast.parse(source)
nodes = tree.body
argval = nodes[1].args.defaults[0]
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [argval])
def test_import_masked_by_assign(self):
source = textwrap.dedent("""\
import aaa
def foo():
aaa = 123
return aaa
aaa
""")
tree = ast.parse(source)
nodes = tree.body
node_aaa = nodes[2].value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [node_aaa])
def test_import_in_decortator(self):
source = textwrap.dedent("""\
import aaa
@aaa.wrapper
def foo(aaa=1):
pass
""")
tree = ast.parse(source)
nodes = tree.body
decorator = nodes[1].decorator_list[0].value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [decorator])
@test_utils.requires_features('type_annotations')
def test_import_in_return_type(self):
source = textwrap.dedent("""\
import aaa
def foo() -> aaa.Foo:
pass
""")
tree = ast.parse(source)
nodes = tree.body
func = nodes[1]
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads, [func.returns.value])
@test_utils.requires_features('type_annotations')
def test_import_in_argument_type(self):
source = textwrap.dedent("""\
import aaa
def foo(bar: aaa.Bar):
pass
""")
tree = ast.parse(source)
nodes = tree.body
func = nodes[1]
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'foo'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads,
[func.args.args[0].annotation.value])
def test_import_attribute_references(self):
source = textwrap.dedent("""\
import aaa.bbb.ccc, ddd.eee
aaa.x()
aaa.bbb.y()
aaa.bbb.ccc.z()
""")
tree = ast.parse(source)
nodes = tree.body
call1 = nodes[1].value.func.value
call2 = nodes[2].value.func.value
call3 = nodes[3].value.func.value
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'ddd'})
self.assertItemsEqual(s.external_references.keys(),
{'aaa', 'aaa.bbb', 'aaa.bbb.ccc', 'ddd', 'ddd.eee'})
self.assertItemsEqual(s.names['aaa'].reads,
[call1, call2.value, call3.value.value])
self.assertItemsEqual(s.names['aaa'].attrs['bbb'].reads,
[call2, call3.value])
self.assertItemsEqual(s.names['aaa'].attrs['bbb'].attrs['ccc'].reads,
[call3])
def test_lookup_scope(self):
src = textwrap.dedent("""\
import a
def b(c, d, e=1):
class F(d):
g = 1
return c
""")
t = ast.parse(src)
import_node, func_node = t.body
class_node, return_node = func_node.body
sc = scope.analyze(t)
import_node_scope = sc.lookup_scope(import_node)
self.assertIs(import_node_scope.node, t)
self.assertIs(import_node_scope, sc)
self.assertItemsEqual(import_node_scope.names, ['a', 'b'])
func_node_scope = sc.lookup_scope(func_node)
self.assertIs(func_node_scope.node, func_node)
self.assertIs(func_node_scope.parent_scope, sc)
self.assertItemsEqual(func_node_scope.names, ['c', 'd', 'e', 'F'])
class_node_scope = sc.lookup_scope(class_node)
self.assertIs(class_node_scope.node, class_node)
self.assertIs(class_node_scope.parent_scope, func_node_scope)
self.assertItemsEqual(class_node_scope.names, ['g'])
return_node_scope = sc.lookup_scope(return_node)
self.assertIs(return_node_scope.node, func_node)
self.assertIs(return_node_scope, func_node_scope)
self.assertItemsEqual(return_node_scope.names, ['c', 'd', 'e', 'F'])
self.assertIs(class_node_scope.lookup_scope(func_node),
func_node_scope)
self.assertIsNone(sc.lookup_scope(ast.Name(id='foo')))
def test_class_methods(self):
source = textwrap.dedent("""\
import aaa
class C:
def aaa(self):
return aaa
def bbb(self):
return aaa
""")
tree = ast.parse(source)
importstmt, classdef = tree.body
method_aaa, method_bbb = classdef.body
s = scope.analyze(tree)
self.assertItemsEqual(s.names.keys(), {'aaa', 'C'})
self.assertItemsEqual(s.external_references.keys(), {'aaa'})
self.assertItemsEqual(s.names['aaa'].reads,
[method_aaa.body[0].value, method_bbb.body[0].value])
# TODO: Test references to C.aaa, C.bbb once supported
def test_vararg_kwarg_references_in_function_body(self):
source = textwrap.dedent("""\
def aaa(bbb, *ccc, **ddd):
ccc
ddd
eee(ccc, ddd)
""")
tree = ast.parse(source)
funcdef, call = tree.body
ccc_expr, ddd_expr = funcdef.body
sc = scope.analyze(tree)
func_scope = sc.lookup_scope(funcdef)
self.assertIn('ccc', func_scope.names)
self.assertItemsEqual(func_scope.names['ccc'].reads, [ccc_expr.value])
self.assertIn('ddd', func_scope.names)
self.assertItemsEqual(func_scope.names['ddd'].reads, [ddd_expr.value])
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(ScopeTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,89 @@
# coding=utf-8
"""Useful stuff for tests."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import sys
import unittest
from six.moves import zip
class TestCase(unittest.TestCase):
def checkAstsEqual(self, a, b):
"""Compares two ASTs and fails if there are differences.
Ignores `ctx` fields and formatting info.
"""
if a is None and b is None:
return
try:
self.assertIsNotNone(a)
self.assertIsNotNone(b)
for node_a, node_b in zip(ast.walk(a), ast.walk(b)):
self.assertEqual(type(node_a), type(node_b))
for field in type(node_a)()._fields:
a_val = getattr(node_a, field, None)
b_val = getattr(node_b, field, None)
if isinstance(a_val, list):
for item_a, item_b in zip(a_val, b_val):
self.checkAstsEqual(item_a, item_b)
elif isinstance(a_val, ast.AST) or isinstance(b_val, ast.AST):
if (not isinstance(a_val, (ast.Load, ast.Store, ast.Param)) and
not isinstance(b_val, (ast.Load, ast.Store, ast.Param))):
self.assertIsNotNone(a_val)
self.assertIsNotNone(b_val)
self.checkAstsEqual(a_val, b_val)
else:
self.assertEqual(a_val, b_val)
except AssertionError as ae:
self.fail('ASTs differ:\n%s\n !=\n%s\n\n%s' % (
ast.dump(a), ast.dump(b), ae))
if not hasattr(TestCase, 'assertItemsEqual'):
setattr(TestCase, 'assertItemsEqual', TestCase.assertCountEqual)
def requires_features(*features):
return unittest.skipIf(
any(not supports_feature(feature) for feature in features),
'Tests features which are not supported by this version of python. '
'Missing: %r' % [f for f in features if not supports_feature(f)])
def supports_feature(feature):
if feature == 'bytes_node':
return hasattr(ast, 'Bytes') and issubclass(ast.Bytes, ast.AST)
if feature == 'exec_node':
return hasattr(ast, 'Exec') and issubclass(ast.Exec, ast.AST)
if feature == 'type_annotations':
try:
ast.parse('def foo(bar: str=123) -> None: pass')
except SyntaxError:
return False
return True
if feature == 'fstring':
return hasattr(ast, 'JoinedStr') and issubclass(ast.JoinedStr, ast.AST)
# Python 2 counts tabs as 8 spaces for indentation
if feature == 'mixed_tabs_spaces':
return sys.version_info[0] < 3
return False
@@ -0,0 +1,66 @@
# coding=utf-8
"""Tests for google3.third_party.py.pasta.base.test_utils."""
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import unittest
from pasta.base import test_utils
class CheckAstEqualityTest(test_utils.TestCase):
def test_empty(self):
src = ""
t = ast.parse(src)
self.checkAstsEqual(t, t)
def test_one_global(self):
src = "X = 1\n"
t = ast.parse(src)
self.checkAstsEqual(t, t)
def test_two_globals(self):
src = "X = 1\nY = 2\n"
t = ast.parse(src)
self.checkAstsEqual(t, t)
def test_different_number_of_nodes(self):
src1 = "X = 1\ndef Foo():\n return None\n"
src2 = src1 + "Y = 2\n"
t1 = ast.parse(src1)
t2 = ast.parse(src2)
with self.assertRaises(AssertionError):
self.checkAstsEqual(t1, t2)
def test_simple_function_def(self):
code = ("def foo(x):\n"
" return x + 1\n")
t = ast.parse(code)
self.checkAstsEqual(t, t)
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(CheckAstEqualityTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,513 @@
# coding=utf-8
"""Token generator for analyzing source code in logical units.
This module contains the TokenGenerator used for annotating a parsed syntax tree
with source code formatting.
"""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import contextlib
import itertools
import tokenize
from six import StringIO
from pasta.base import formatting as fmt
from pasta.base import fstring_utils
# Alias for extracting token names
TOKENS = tokenize
Token = collections.namedtuple('Token', ('type', 'src', 'start', 'end', 'line'))
FORMATTING_TOKENS = (TOKENS.INDENT, TOKENS.DEDENT, TOKENS.NL, TOKENS.NEWLINE,
TOKENS.COMMENT)
class TokenGenerator(object):
"""Helper for sequentially parsing Python source code, token by token.
Holds internal state during parsing, including:
_tokens: List of tokens in the source code, as parsed by `tokenize` module.
_parens: Stack of open parenthesis at the current point in parsing.
_hints: Number of open parentheses, brackets, etc. at the current point.
_scope_stack: Stack containing tuples of nodes where the last parenthesis that
was open is related to one of the nodes on the top of the stack.
_lines: Full lines of the source code.
_i: Index of the last token that was parsed. Initially -1.
_loc: (lineno, column_offset) pair of the position in the source that has been
parsed to. This should be either the start or end of the token at index _i.
Arguments:
ignore_error_tokens: If True, will ignore error tokens. Otherwise, an error
token will cause an exception. This is useful when the source being parsed
contains invalid syntax, e.g. if it is in an fstring context.
"""
def __init__(self, source, ignore_error_token=False):
self.lines = source.splitlines(True)
self._tokens = list(_generate_tokens(source, ignore_error_token))
self._parens = []
self._hints = 0
self._scope_stack = []
self._len = len(self._tokens)
self._i = -1
self._loc = self.loc_begin()
def chars_consumed(self):
return len(self._space_between((1, 0), self._tokens[self._i].end))
def loc_begin(self):
"""Get the start column of the current location parsed to."""
if self._i < 0:
return (1, 0)
return self._tokens[self._i].start
def loc_end(self):
"""Get the end column of the current location parsed to."""
if self._i < 0:
return (1, 0)
return self._tokens[self._i].end
def peek(self):
"""Get the next token without advancing."""
if self._i + 1 >= self._len:
return None
return self._tokens[self._i + 1]
def peek_non_whitespace(self):
"""Get the next non-whitespace token without advancing."""
return self.peek_conditional(lambda t: t.type not in FORMATTING_TOKENS)
def peek_conditional(self, condition):
"""Get the next token of the given type without advancing."""
return next((t for t in self._tokens[self._i + 1:] if condition(t)), None)
def next(self, advance=True):
"""Consume the next token and optionally advance the current location."""
self._i += 1
if self._i >= self._len:
return None
if advance:
self._loc = self._tokens[self._i].end
return self._tokens[self._i]
def rewind(self, amount=1):
"""Rewind the token iterator."""
self._i -= amount
def whitespace(self, max_lines=None, comment=False):
"""Parses whitespace from the current _loc to the next non-whitespace.
Arguments:
max_lines: (optional int) Maximum number of lines to consider as part of
the whitespace. Valid values are None, 0 and 1.
comment: (boolean) If True, look for a trailing comment even when not in
a parenthesized scope.
Pre-condition:
`_loc' represents the point before which everything has been parsed and
after which nothing has been parsed.
Post-condition:
`_loc' is exactly at the character that was parsed to.
"""
next_token = self.peek()
if not comment and next_token and next_token.type == TOKENS.COMMENT:
return ''
def predicate(token):
return (token.type in (TOKENS.INDENT, TOKENS.DEDENT) or
token.type == TOKENS.COMMENT and (comment or self._hints) or
token.type == TOKENS.ERRORTOKEN and token.src == ' ' or
max_lines is None and token.type in (TOKENS.NL, TOKENS.NEWLINE))
whitespace = list(self.takewhile(predicate, advance=False))
next_token = self.peek()
result = ''
for tok in itertools.chain(whitespace,
((next_token,) if next_token else ())):
result += self._space_between(self._loc, tok.start)
if tok != next_token:
result += tok.src
self._loc = tok.end
else:
self._loc = tok.start
# Eat a single newline character
if ((max_lines is None or max_lines > 0) and
next_token and next_token.type in (TOKENS.NL, TOKENS.NEWLINE)):
result += self.next().src
return result
def block_whitespace(self, indent_level):
"""Parses whitespace from the current _loc to the end of the block."""
# Get the normal suffix lines, but don't advance the token index unless
# there is no indentation to account for
start_i = self._i
full_whitespace = self.whitespace(comment=True)
if not indent_level:
return full_whitespace
self._i = start_i
# Trim the full whitespace into only lines that match the indentation level
lines = full_whitespace.splitlines(True)
try:
last_line_idx = next(i for i, line in reversed(list(enumerate(lines)))
if line.startswith(indent_level + '#'))
except StopIteration:
# No comment lines at the end of this block
self._loc = self._tokens[self._i].end
return ''
lines = lines[:last_line_idx + 1]
# Advance the current location to the last token in the lines we've read
end_line = self._tokens[self._i].end[0] + 1 + len(lines)
list(self.takewhile(lambda tok: tok.start[0] < end_line))
self._loc = self._tokens[self._i].end
return ''.join(lines)
def dots(self, num_dots):
"""Parse a number of dots.
This is to work around an oddity in python3's tokenizer, which treats three
`.` tokens next to each other in a FromImport's level as an ellipsis. This
parses until the expected number of dots have been seen.
"""
result = ''
dots_seen = 0
prev_loc = self._loc
while dots_seen < num_dots:
tok = self.next()
assert tok.src in ('.', '...')
result += self._space_between(prev_loc, tok.start) + tok.src
dots_seen += tok.src.count('.')
prev_loc = self._loc
return result
def open_scope(self, node, single_paren=False):
"""Open a parenthesized scope on the given node."""
result = ''
parens = []
start_i = self._i
start_loc = prev_loc = self._loc
# Eat whitespace or '(' tokens one at a time
for tok in self.takewhile(
lambda t: t.type in FORMATTING_TOKENS or t.src == '('):
# Stores all the code up to and including this token
result += self._space_between(prev_loc, tok.start)
if tok.src == '(' and single_paren and parens:
self.rewind()
self._loc = tok.start
break
result += tok.src
if tok.src == '(':
# Start a new scope
parens.append(result)
result = ''
start_i = self._i
start_loc = self._loc
prev_loc = self._loc
if parens:
# Add any additional whitespace on to the last open-paren
next_tok = self.peek()
parens[-1] += result + self._space_between(self._loc, next_tok.start)
self._loc = next_tok.start
# Add each paren onto the stack
for paren in parens:
self._parens.append(paren)
self._scope_stack.append(_scope_helper(node))
else:
# No parens were encountered, then reset like this method did nothing
self._i = start_i
self._loc = start_loc
def close_scope(self, node, prefix_attr='prefix', suffix_attr='suffix',
trailing_comma=False, single_paren=False):
"""Close a parenthesized scope on the given node, if one is open."""
# Ensures the prefix + suffix are not None
if fmt.get(node, prefix_attr) is None:
fmt.set(node, prefix_attr, '')
if fmt.get(node, suffix_attr) is None:
fmt.set(node, suffix_attr, '')
if not self._parens or node not in self._scope_stack[-1]:
return
symbols = {')'}
if trailing_comma:
symbols.add(',')
parsed_to_i = self._i
parsed_to_loc = prev_loc = self._loc
encountered_paren = False
result = ''
for tok in self.takewhile(
lambda t: t.type in FORMATTING_TOKENS or t.src in symbols):
# Consume all space up to this token
result += self._space_between(prev_loc, tok.start)
if tok.src == ')' and single_paren and encountered_paren:
self.rewind()
parsed_to_i = self._i
parsed_to_loc = tok.start
fmt.append(node, suffix_attr, result)
break
# Consume the token itself
result += tok.src
if tok.src == ')':
# Close out the open scope
encountered_paren = True
self._scope_stack.pop()
fmt.prepend(node, prefix_attr, self._parens.pop())
fmt.append(node, suffix_attr, result)
result = ''
parsed_to_i = self._i
parsed_to_loc = tok.end
if not self._parens or node not in self._scope_stack[-1]:
break
prev_loc = tok.end
# Reset back to the last place where we parsed anything
self._i = parsed_to_i
self._loc = parsed_to_loc
def hint_open(self):
"""Indicates opening a group of parentheses or brackets."""
self._hints += 1
def hint_closed(self):
"""Indicates closing a group of parentheses or brackets."""
self._hints -= 1
if self._hints < 0:
raise ValueError('Hint value negative')
@contextlib.contextmanager
def scope(self, node, attr=None, trailing_comma=False):
"""Context manager to handle a parenthesized scope."""
self.open_scope(node, single_paren=(attr is not None))
yield
if attr:
self.close_scope(node, prefix_attr=attr + '_prefix',
suffix_attr=attr + '_suffix',
trailing_comma=trailing_comma,
single_paren=True)
else:
self.close_scope(node, trailing_comma=trailing_comma)
def is_in_scope(self):
"""Return True iff there is a scope open."""
return self._parens or self._hints
def str(self):
"""Parse a full string literal from the input."""
def predicate(token):
return (token.type in (TOKENS.STRING, TOKENS.COMMENT) or
self.is_in_scope() and token.type in (TOKENS.NL, TOKENS.NEWLINE))
return self.eat_tokens(predicate)
def eat_tokens(self, predicate):
"""Parse input from tokens while a given condition is met."""
content = ''
prev_loc = self._loc
tok = None
for tok in self.takewhile(predicate, advance=False):
content += self._space_between(prev_loc, tok.start)
content += tok.src
prev_loc = tok.end
if tok:
self._loc = tok.end
return content
def fstr(self):
"""Parses an fstring, including subexpressions.
Returns:
A generator function which, when repeatedly reads a chunk of the fstring
up until the next subexpression and yields that chunk, plus a new token
generator to use to parse the subexpression. The subexpressions in the
original fstring data are replaced by placeholders to make it possible to
fill them in with new values, if desired.
"""
def fstr_parser():
# Reads the whole fstring as a string, then parses it char by char
if self.peek_non_whitespace().type == TOKENS.STRING:
# Normal fstrings are one ore more STRING tokens, maybe mixed with
# spaces, e.g.: f"Hello, {name}"
str_content = self.str()
else:
# Format specifiers in fstrings are also JoinedStr nodes, but these are
# arbitrary expressions, e.g. in: f"{value:{width}.{precision}}", the
# format specifier is an fstring: "{width}.{precision}" but these are
# not STRING tokens.
def fstr_eater(tok):
if tok.type == TOKENS.OP and tok.src == '}':
if fstr_eater.level <= 0:
return False
fstr_eater.level -= 1
if tok.type == TOKENS.OP and tok.src == '{':
fstr_eater.level += 1
return True
fstr_eater.level = 0
str_content = self.eat_tokens(fstr_eater)
indexed_chars = enumerate(str_content)
val_idx = 0
i = -1
result = ''
while i < len(str_content) - 1:
i, c = next(indexed_chars)
result += c
# When an open bracket is encountered, start parsing a subexpression
if c == '{':
# First check if this is part of an escape sequence
# (f"{{" is used to escape a bracket literal)
nexti, nextc = next(indexed_chars)
if nextc == '{':
result += c
continue
indexed_chars = itertools.chain([(nexti, nextc)], indexed_chars)
# Add a placeholder onto the result
result += fstring_utils.placeholder(val_idx) + '}'
val_idx += 1
# Yield a new token generator to parse the subexpression only
tg = TokenGenerator(str_content[i+1:], ignore_error_token=True)
yield (result, tg)
result = ''
# Skip the number of characters consumed by the subexpression
for tg_i in range(tg.chars_consumed()):
i, c = next(indexed_chars)
# Eat up to and including the close bracket
i, c = next(indexed_chars)
while c != '}':
i, c = next(indexed_chars)
# Yield the rest of the fstring, when done
yield (result, None)
return fstr_parser
def _space_between(self, start_loc, end_loc):
"""Parse the space between a location and the next token"""
if start_loc > end_loc:
raise ValueError('start_loc > end_loc', start_loc, end_loc)
if start_loc[0] > len(self.lines):
return ''
prev_row, prev_col = start_loc
end_row, end_col = end_loc
if prev_row == end_row:
return self.lines[prev_row - 1][prev_col:end_col]
return ''.join(itertools.chain(
(self.lines[prev_row - 1][prev_col:],),
self.lines[prev_row:end_row - 1],
(self.lines[end_row - 1][:end_col],) if end_col > 0 else '',
))
def next_name(self):
"""Parse the next name token."""
last_i = self._i
def predicate(token):
return token.type != TOKENS.NAME
unused_tokens = list(self.takewhile(predicate, advance=False))
result = self.next(advance=False)
self._i = last_i
return result
def next_of_type(self, token_type):
"""Parse a token of the given type and return it."""
token = self.next()
if token.type != token_type:
raise ValueError("Expected %r but found %r\nline %d: %s" % (
tokenize.tok_name[token_type], token.src, token.start[0],
self.lines[token.start[0] - 1]))
return token
def takewhile(self, condition, advance=True):
"""Parse tokens as long as a condition holds on the next token."""
prev_loc = self._loc
token = self.next(advance=advance)
while token is not None and condition(token):
yield token
prev_loc = self._loc
token = self.next(advance=advance)
self.rewind()
self._loc = prev_loc
def _scope_helper(node):
"""Get the closure of nodes that could begin a scope at this point.
For instance, when encountering a `(` when parsing a BinOp node, this could
indicate that the BinOp itself is parenthesized OR that the BinOp's left node
could be parenthesized.
E.g.: (a + b * c) or (a + b) * c or (a) + b * c
^ ^ ^
Arguments:
node: (ast.AST) Node encountered when opening a scope.
Returns:
A closure of nodes which that scope might apply to.
"""
if isinstance(node, ast.Attribute):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Subscript):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Assign):
return (node,) + _scope_helper(node.targets[0])
if isinstance(node, ast.AugAssign):
return (node,) + _scope_helper(node.target)
if isinstance(node, ast.Expr):
return (node,) + _scope_helper(node.value)
if isinstance(node, ast.Compare):
return (node,) + _scope_helper(node.left)
if isinstance(node, ast.BoolOp):
return (node,) + _scope_helper(node.values[0])
if isinstance(node, ast.BinOp):
return (node,) + _scope_helper(node.left)
if isinstance(node, ast.Tuple) and node.elts:
return (node,) + _scope_helper(node.elts[0])
if isinstance(node, ast.Call):
return (node,) + _scope_helper(node.func)
if isinstance(node, ast.GeneratorExp):
return (node,) + _scope_helper(node.elt)
if isinstance(node, ast.IfExp):
return (node,) + _scope_helper(node.body)
return (node,)
def _generate_tokens(source, ignore_error_token=False):
token_generator = tokenize.generate_tokens(StringIO(source).readline)
try:
for tok in token_generator:
yield Token(*tok)
except tokenize.TokenError:
if not ignore_error_token:
raise