Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,23 @@
# coding=utf-8
"""Errors that can occur during augmentation."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class InvalidAstError(Exception):
"""Occurs when the syntax tree does not meet some expected condition."""
@@ -0,0 +1,217 @@
# coding=utf-8
"""Functions for dealing with import statements."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import copy
import logging
from pasta.augment import errors
from pasta.base import ast_utils
from pasta.base import scope
def add_import(tree, name_to_import, asname=None, from_import=True, merge_from_imports=True):
"""Adds an import to the module.
This function will try to ensure not to create duplicate imports. If name_to_import is
already imported, it will return the existing import. This is true even if asname is set
(asname will be ignored, and the existing name will be returned).
If the import would create a name that already exists in the scope given by tree, this
function will "import as", and append "_x" to the asname where x is the smallest positive
integer generating a unique name.
Arguments:
tree: (ast.Module) Module AST to modify.
name_to_import: (string) The absolute name to import.
asname: (string) The alias for the import ("import name_to_import as asname")
from_import: (boolean) If True, import the name using an ImportFrom node.
merge_from_imports: (boolean) If True, merge a newly inserted ImportFrom
node into an existing ImportFrom node, if applicable.
Returns:
The name (as a string) that can be used to reference the imported name. This
can be the fully-qualified name, the basename, or an alias name.
"""
sc = scope.analyze(tree)
# Don't add anything if it's already imported
if name_to_import in sc.external_references:
existing_ref = next((ref for ref in sc.external_references[name_to_import]
if ref.name_ref is not None), None)
if existing_ref:
return existing_ref.name_ref.id
import_node = None
added_name = None
def make_safe_alias_node(alias_name, asname):
# Try to avoid name conflicts
new_alias = ast.alias(name=alias_name, asname=asname)
imported_name = asname or alias_name
counter = 0
while imported_name in sc.names:
counter += 1
imported_name = new_alias.asname = '%s_%d' % (asname or alias_name,
counter)
return new_alias
# Add an ImportFrom node if requested and possible
if from_import and '.' in name_to_import:
from_module, alias_name = name_to_import.rsplit('.', 1)
new_alias = make_safe_alias_node(alias_name, asname)
if merge_from_imports:
# Try to add to an existing ImportFrom from the same module
existing_from_import = next(
(node for node in tree.body if isinstance(node, ast.ImportFrom)
and node.module == from_module and node.level == 0), None)
if existing_from_import:
existing_from_import.names.append(new_alias)
return new_alias.asname or new_alias.name
# Create a new node for this import
import_node = ast.ImportFrom(module=from_module, names=[new_alias], level=0)
# If not already created as an ImportFrom, create a normal Import node
if not import_node:
new_alias = make_safe_alias_node(alias_name=name_to_import, asname=asname)
import_node = ast.Import(
names=[new_alias])
# Insert the node at the top of the module and return the name in scope
tree.body.insert(1 if ast_utils.has_docstring(tree) else 0, import_node)
return new_alias.asname or new_alias.name
def split_import(sc, node, alias_to_remove):
"""Split an import node by moving the given imported alias into a new import.
Arguments:
sc: (scope.Scope) Scope computed on whole tree of the code being modified.
node: (ast.Import|ast.ImportFrom) An import node to split.
alias_to_remove: (ast.alias) The import alias node to remove. This must be a
child of the given `node` argument.
Raises:
errors.InvalidAstError: if `node` is not appropriately contained in the tree
represented by the scope `sc`.
"""
parent = sc.parent(node)
parent_list = None
for a in ('body', 'orelse', 'finalbody'):
if hasattr(parent, a) and node in getattr(parent, a):
parent_list = getattr(parent, a)
break
else:
raise errors.InvalidAstError('Unable to find list containing import %r on '
'parent node %r' % (node, parent))
idx = parent_list.index(node)
new_import = copy.deepcopy(node)
new_import.names = [alias_to_remove]
node.names.remove(alias_to_remove)
parent_list.insert(idx + 1, new_import)
return new_import
def get_unused_import_aliases(tree, sc=None):
"""Get the import aliases that aren't used.
Arguments:
tree: (ast.AST) An ast to find imports in.
sc: A scope.Scope representing tree (generated from scratch if not
provided).
Returns:
A list of ast.alias representing imported aliases that aren't referenced in
the given tree.
"""
if sc is None:
sc = scope.analyze(tree)
unused_aliases = set()
for node in ast.walk(tree):
if isinstance(node, ast.alias):
str_name = node.asname if node.asname is not None else node.name
if str_name in sc.names:
name = sc.names[str_name]
if not name.reads:
unused_aliases.add(node)
else:
# This happens because of https://github.com/google/pasta/issues/32
logging.warning('Imported name %s not found in scope (perhaps it\'s '
'imported dynamically)', str_name)
return unused_aliases
def remove_import_alias_node(sc, node):
"""Remove an alias and if applicable remove their entire import.
Arguments:
sc: (scope.Scope) Scope computed on whole tree of the code being modified.
node: (ast.Import|ast.ImportFrom|ast.alias) The node to remove.
"""
import_node = sc.parent(node)
if len(import_node.names) == 1:
import_parent = sc.parent(import_node)
ast_utils.remove_child(import_parent, import_node)
else:
ast_utils.remove_child(import_node, node)
def remove_duplicates(tree, sc=None):
"""Remove duplicate imports, where it is safe to do so.
This does NOT remove imports that create new aliases
Arguments:
tree: (ast.AST) An ast to modify imports in.
sc: A scope.Scope representing tree (generated from scratch if not
provided).
Returns:
Whether any changes were made.
"""
if sc is None:
sc = scope.analyze(tree)
modified = False
seen_names = set()
for node in tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
for alias in list(node.names):
import_node = sc.parent(alias)
if isinstance(import_node, ast.Import):
full_name = alias.name
elif import_node.module:
full_name = '%s%s.%s' % ('.' * import_node.level,
import_node.module, alias.name)
else:
full_name = '%s%s' % ('.' * import_node.level, alias.name)
full_name += ':' + (alias.asname or alias.name)
if full_name in seen_names:
remove_import_alias_node(sc, alias)
modified = True
else:
seen_names.add(full_name)
return modified
@@ -0,0 +1,428 @@
# coding=utf-8
"""Tests for import_utils."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import traceback
import unittest
import pasta
from pasta.augment import import_utils
from pasta.base import ast_utils
from pasta.base import test_utils
from pasta.base import scope
class SplitImportTest(test_utils.TestCase):
def test_split_normal_import(self):
src = 'import aaa, bbb, ccc\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual(ast.Import, type(t.body[1]))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa', 'ccc'])
self.assertEqual([alias.name for alias in t.body[1].names], ['bbb'])
def test_split_from_import(self):
src = 'from aaa import bbb, ccc, ddd\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual(ast.ImportFrom, type(t.body[1]))
self.assertEqual(t.body[0].module, 'aaa')
self.assertEqual(t.body[1].module, 'aaa')
self.assertEqual([alias.name for alias in t.body[0].names], ['bbb', 'ddd'])
def test_split_imports_with_alias(self):
src = 'import aaa as a, bbb as b, ccc as c\n'
t = ast.parse(src)
import_node = t.body[0]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, import_node.names[1])
self.assertEqual(2, len(t.body))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa', 'ccc'])
self.assertEqual([alias.name for alias in t.body[1].names], ['bbb'])
self.assertEqual(t.body[1].names[0].asname, 'b')
def test_split_imports_multiple(self):
src = 'import aaa, bbb, ccc\n'
t = ast.parse(src)
import_node = t.body[0]
alias_bbb = import_node.names[1]
alias_ccc = import_node.names[2]
sc = scope.analyze(t)
import_utils.split_import(sc, import_node, alias_bbb)
import_utils.split_import(sc, import_node, alias_ccc)
self.assertEqual(3, len(t.body))
self.assertEqual([alias.name for alias in t.body[0].names], ['aaa'])
self.assertEqual([alias.name for alias in t.body[1].names], ['ccc'])
self.assertEqual([alias.name for alias in t.body[2].names], ['bbb'])
def test_split_nested_imports(self):
test_cases = (
'def foo():\n {import_stmt}\n',
'class Foo(object):\n {import_stmt}\n',
'if foo:\n {import_stmt}\nelse:\n pass\n',
'if foo:\n pass\nelse:\n {import_stmt}\n',
'if foo:\n pass\nelif bar:\n {import_stmt}\n',
'try:\n {import_stmt}\nexcept:\n pass\n',
'try:\n pass\nexcept:\n {import_stmt}\n',
'try:\n pass\nfinally:\n {import_stmt}\n',
'for i in foo:\n {import_stmt}\n',
'for i in foo:\n pass\nelse:\n {import_stmt}\n',
'while foo:\n {import_stmt}\n',
)
for template in test_cases:
try:
src = template.format(import_stmt='import aaa, bbb, ccc')
t = ast.parse(src)
sc = scope.analyze(t)
import_node = ast_utils.find_nodes_by_type(t, ast.Import)[0]
import_utils.split_import(sc, import_node, import_node.names[1])
split_import_nodes = ast_utils.find_nodes_by_type(t, ast.Import)
self.assertEqual(1, len(t.body))
self.assertEqual(2, len(split_import_nodes))
self.assertEqual([alias.name for alias in split_import_nodes[0].names],
['aaa', 'ccc'])
self.assertEqual([alias.name for alias in split_import_nodes[1].names],
['bbb'])
except:
self.fail('Failed while executing case:\n%s\nCaused by:\n%s' %
(src, traceback.format_exc()))
class GetUnusedImportsTest(test_utils.TestCase):
def test_normal_imports(self):
src = """\
import a
import b
a.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[1].names[0]])
def test_import_from(self):
src = """\
from my_module import a
import b
from my_module import c
b.foo()
c.bar()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[0]])
def test_import_from_alias(self):
src = """\
from my_module import a, b
b.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[0]])
def test_import_asname(self):
src = """\
from my_module import a as a_mod, b as unused_b_mod
import c as c_mod, d as unused_d_mod
a_mod.foo()
c_mod.foo()
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[tree.body[0].names[1],
tree.body[1].names[1]])
def test_dynamic_import(self):
# For now we just don't want to error out on these, longer
# term we want to do the right thing (see
# https://github.com/google/pasta/issues/32)
src = """\
def foo():
import bar
"""
tree = ast.parse(src)
self.assertItemsEqual(import_utils.get_unused_import_aliases(tree),
[])
class RemoveImportTest(test_utils.TestCase):
# Note that we don't test any 'asname' examples but as far as remove_import_alias_node
# is concerned its not a different case because its still just an alias type
# and we don't care about the internals of the alias we're trying to remove.
def test_remove_just_alias(self):
src = "import a, b"
tree = ast.parse(src)
sc = scope.analyze(tree)
unused_b_node = tree.body[0].names[1]
import_utils.remove_import_alias_node(sc, unused_b_node)
self.assertEqual(len(tree.body), 1)
self.assertEqual(type(tree.body[0]), ast.Import)
self.assertEqual(len(tree.body[0].names), 1)
self.assertEqual(tree.body[0].names[0].name, 'a')
def test_remove_just_alias_import_from(self):
src = "from m import a, b"
tree = ast.parse(src)
sc = scope.analyze(tree)
unused_b_node = tree.body[0].names[1]
import_utils.remove_import_alias_node(sc, unused_b_node)
self.assertEqual(len(tree.body), 1)
self.assertEqual(type(tree.body[0]), ast.ImportFrom)
self.assertEqual(len(tree.body[0].names), 1)
self.assertEqual(tree.body[0].names[0].name, 'a')
def test_remove_full_import(self):
src = "import a"
tree = ast.parse(src)
sc = scope.analyze(tree)
a_node = tree.body[0].names[0]
import_utils.remove_import_alias_node(sc, a_node)
self.assertEqual(len(tree.body), 0)
def test_remove_full_importfrom(self):
src = "from m import a"
tree = ast.parse(src)
sc = scope.analyze(tree)
a_node = tree.body[0].names[0]
import_utils.remove_import_alias_node(sc, a_node)
self.assertEqual(len(tree.body), 0)
class AddImportTest(test_utils.TestCase):
def test_add_normal_import(self):
tree = ast.parse('')
self.assertEqual('a.b.c',
import_utils.add_import(tree, 'a.b.c', from_import=False))
self.assertEqual('import a.b.c\n', pasta.dump(tree))
def test_add_normal_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'd',
import_utils.add_import(tree, 'a.b.c', asname='d', from_import=False)
)
self.assertEqual('import a.b.c as d\n', pasta.dump(tree))
def test_add_from_import(self):
tree = ast.parse('')
self.assertEqual('c',
import_utils.add_import(tree, 'a.b.c', from_import=True))
self.assertEqual('from a.b import c\n', pasta.dump(tree))
def test_add_from_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'd',
import_utils.add_import(tree, 'a.b.c', asname='d', from_import=True)
)
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_single_name_from_import(self):
tree = ast.parse('')
self.assertEqual('foo',
import_utils.add_import(tree, 'foo', from_import=True))
self.assertEqual('import foo\n', pasta.dump(tree))
def test_add_single_name_from_import_with_asname(self):
tree = ast.parse('')
self.assertEqual(
'bar',
import_utils.add_import(tree, 'foo', asname='bar', from_import=True)
)
self.assertEqual('import foo as bar\n', pasta.dump(tree))
def test_add_existing_import(self):
tree = ast.parse('from a.b import c')
self.assertEqual('c', import_utils.add_import(tree, 'a.b.c'))
self.assertEqual('from a.b import c\n', pasta.dump(tree))
def test_add_existing_import_aliased(self):
tree = ast.parse('from a.b import c as d')
self.assertEqual('d', import_utils.add_import(tree, 'a.b.c'))
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_existing_import_aliased_with_asname(self):
tree = ast.parse('from a.b import c as d')
self.assertEqual('d', import_utils.add_import(tree, 'a.b.c', asname='e'))
self.assertEqual('from a.b import c as d\n', pasta.dump(tree))
def test_add_existing_import_normal_import(self):
tree = ast.parse('import a.b.c')
self.assertEqual('a.b',
import_utils.add_import(tree, 'a.b', from_import=False))
self.assertEqual('import a.b.c\n', pasta.dump(tree))
def test_add_existing_import_normal_import_aliased(self):
tree = ast.parse('import a.b.c as d')
self.assertEqual('a.b',
import_utils.add_import(tree, 'a.b', from_import=False))
self.assertEqual('d',
import_utils.add_import(tree, 'a.b.c', from_import=False))
self.assertEqual('import a.b\nimport a.b.c as d\n', pasta.dump(tree))
def test_add_import_with_conflict(self):
tree = ast.parse('def c(): pass\n')
self.assertEqual('c_1',
import_utils.add_import(tree, 'a.b.c', from_import=True))
self.assertEqual(
'from a.b import c as c_1\ndef c():\n pass\n', pasta.dump(tree))
def test_add_import_with_asname_with_conflict(self):
tree = ast.parse('def c(): pass\n')
self.assertEqual('c_1',
import_utils.add_import(tree, 'a.b', asname='c', from_import=True))
self.assertEqual(
'from a import b as c_1\ndef c():\n pass\n', pasta.dump(tree))
def test_merge_from_import(self):
tree = ast.parse('from a.b import c')
# x is explicitly not merged
self.assertEqual('x', import_utils.add_import(tree, 'a.b.x',
merge_from_imports=False))
self.assertEqual('from a.b import x\nfrom a.b import c\n',
pasta.dump(tree))
# y is allowed to be merged and is grouped into the first matching import
self.assertEqual('y', import_utils.add_import(tree, 'a.b.y',
merge_from_imports=True))
self.assertEqual('from a.b import x, y\nfrom a.b import c\n',
pasta.dump(tree))
def test_add_import_after_docstring(self):
tree = ast.parse('\'Docstring.\'')
self.assertEqual('a', import_utils.add_import(tree, 'a'))
self.assertEqual('\'Docstring.\'\nimport a\n', pasta.dump(tree))
class RemoveDuplicatesTest(test_utils.TestCase):
def test_remove_duplicates(self):
src = """
import a
import b
import c
import b
import d
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 4)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[1].names[0].name, 'b')
self.assertEqual(tree.body[2].names[0].name, 'c')
self.assertEqual(tree.body[3].names[0].name, 'd')
def test_remove_duplicates_multiple(self):
src = """
import a, b
import b, c
import d, a, e, f
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 3)
self.assertEqual(len(tree.body[0].names), 2)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[0].names[1].name, 'b')
self.assertEqual(len(tree.body[1].names), 1)
self.assertEqual(tree.body[1].names[0].name, 'c')
self.assertEqual(len(tree.body[2].names), 3)
self.assertEqual(tree.body[2].names[0].name, 'd')
self.assertEqual(tree.body[2].names[1].name, 'e')
self.assertEqual(tree.body[2].names[2].name, 'f')
def test_remove_duplicates_empty_node(self):
src = """
import a, b, c
import b, c
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 1)
self.assertEqual(len(tree.body[0].names), 3)
self.assertEqual(tree.body[0].names[0].name, 'a')
self.assertEqual(tree.body[0].names[1].name, 'b')
self.assertEqual(tree.body[0].names[2].name, 'c')
def test_remove_duplicates_normal_and_from(self):
src = """
import a.b
from a import b
"""
tree = ast.parse(src)
self.assertFalse(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 2)
def test_remove_duplicates_aliases(self):
src = """
import a
import a as ax
import a as ax2
import a as ax
"""
tree = ast.parse(src)
self.assertTrue(import_utils.remove_duplicates(tree))
self.assertEqual(len(tree.body), 3)
self.assertEqual(tree.body[0].names[0].asname, None)
self.assertEqual(tree.body[1].names[0].asname, 'ax')
self.assertEqual(tree.body[2].names[0].asname, 'ax2')
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(SplitImportTest))
result.addTests(unittest.makeSuite(GetUnusedImportsTest))
result.addTests(unittest.makeSuite(RemoveImportTest))
result.addTests(unittest.makeSuite(AddImportTest))
result.addTests(unittest.makeSuite(RemoveDuplicatesTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,65 @@
# coding=utf-8
"""Inline constants in a python module."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import copy
from pasta.base import ast_utils
from pasta.base import scope
class InlineError(Exception):
pass
def inline_name(t, name):
"""Inline a constant name into a module."""
sc = scope.analyze(t)
name_node = sc.names[name]
# The name must be a Name node (not a FunctionDef, etc.)
if not isinstance(name_node.definition, ast.Name):
raise InlineError('%r is not a constant; it has type %r' % (
name, type(name_node.definition)))
assign_node = sc.parent(name_node.definition)
if not isinstance(assign_node, ast.Assign):
raise InlineError('%r is not declared in an assignment' % name)
value = assign_node.value
if not isinstance(sc.parent(assign_node), ast.Module):
raise InlineError('%r is not a top-level name' % name)
# If the name is written anywhere else in this module, it is not constant
for ref in name_node.reads:
if isinstance(getattr(ref, 'ctx', None), ast.Store):
raise InlineError('%r is not a constant' % name)
# Replace all reads of the name with a copy of its value
for ref in name_node.reads:
ast_utils.replace_child(sc.parent(ref), ref, copy.deepcopy(value))
# Remove the assignment to this name
if len(assign_node.targets) == 1:
ast_utils.remove_child(sc.parent(assign_node), assign_node)
else:
tgt_list = [tgt for tgt in assign_node.targets
if not (isinstance(tgt, ast.Name) and tgt.id == name)]
assign_node.targets = tgt_list
@@ -0,0 +1,97 @@
# coding=utf-8
"""Tests for augment.inline."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import textwrap
import unittest
from pasta.augment import inline
from pasta.base import test_utils
class InlineTest(test_utils.TestCase):
def test_inline_simple(self):
src = 'x = 1\na = x\n'
t = ast.parse(src)
inline.inline_name(t, 'x')
self.checkAstsEqual(t, ast.parse('a = 1\n'))
def test_inline_multiple_targets(self):
src = 'x = y = z = 1\na = x + y\n'
t = ast.parse(src)
inline.inline_name(t, 'y')
self.checkAstsEqual(t, ast.parse('x = z = 1\na = x + 1\n'))
def test_inline_multiple_reads(self):
src = textwrap.dedent('''\
CONSTANT = "foo"
def a(b=CONSTANT):
return b == CONSTANT
''')
expected = textwrap.dedent('''\
def a(b="foo"):
return b == "foo"
''')
t = ast.parse(src)
inline.inline_name(t, 'CONSTANT')
self.checkAstsEqual(t, ast.parse(expected))
def test_inline_non_constant_fails(self):
src = textwrap.dedent('''\
NOT_A_CONSTANT = "foo"
NOT_A_CONSTANT += "bar"
''')
t = ast.parse(src)
with self.assertRaisesRegexp(inline.InlineError,
'\'NOT_A_CONSTANT\' is not a constant'):
inline.inline_name(t, 'NOT_A_CONSTANT')
def test_inline_function_fails(self):
src = 'def func(): pass\nfunc()\n'
t = ast.parse(src)
with self.assertRaisesRegexp(
inline.InlineError,
'\'func\' is not a constant; it has type %r' % ast.FunctionDef):
inline.inline_name(t, 'func')
def test_inline_conditional_fails(self):
src = 'if define:\n x = 1\na = x\n'
t = ast.parse(src)
with self.assertRaisesRegexp(inline.InlineError,
'\'x\' is not a top-level name'):
inline.inline_name(t, 'x')
def test_inline_non_assign_fails(self):
src = 'CONSTANT1, CONSTANT2 = values'
t = ast.parse(src)
with self.assertRaisesRegexp(
inline.InlineError, '\'CONSTANT1\' is not declared in an assignment'):
inline.inline_name(t, 'CONSTANT1')
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(InlineTest))
return result
if __name__ == '__main__':
unittest.main()
@@ -0,0 +1,154 @@
# coding=utf-8
"""Rename names in a python module."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import six
from pasta.augment import import_utils
from pasta.base import ast_utils
from pasta.base import scope
def rename_external(t, old_name, new_name):
"""Rename an imported name in a module.
This will rewrite all import statements in `tree` that reference the old
module as well as any names in `tree` which reference the imported name. This
may introduce new import statements, but only if necessary.
For example, to move and rename the module `foo.bar.utils` to `foo.bar_utils`:
> rename_external(tree, 'foo.bar.utils', 'foo.bar_utils')
- import foo.bar.utils
+ import foo.bar_utils
- from foo.bar import utils
+ from foo import bar_utils
- from foo.bar import logic, utils
+ from foo.bar import logic
+ from foo import bar_utils
Arguments:
t: (ast.Module) Module syntax tree to perform the rename in. This will be
updated as a result of this function call with all affected nodes changed
and potentially new Import/ImportFrom nodes added.
old_name: (string) Fully-qualified path of the name to replace.
new_name: (string) Fully-qualified path of the name to update to.
Returns:
True if any changes were made, False otherwise.
"""
sc = scope.analyze(t)
if old_name not in sc.external_references:
return False
has_changed = False
renames = {}
already_changed = []
for ref in sc.external_references[old_name]:
if isinstance(ref.node, ast.alias):
parent = sc.parent(ref.node)
# An alias may be the most specific reference to an imported name, but it
# could if it is a child of an ImportFrom, the ImportFrom node's module
# may also need to be updated.
if isinstance(parent, ast.ImportFrom) and parent not in already_changed:
assert _rename_name_in_importfrom(sc, parent, old_name, new_name)
renames[old_name.rsplit('.', 1)[-1]] = new_name.rsplit('.', 1)[-1]
already_changed.append(parent)
else:
ref.node.name = new_name + ref.node.name[len(old_name):]
if not ref.node.asname:
renames[old_name] = new_name
has_changed = True
elif isinstance(ref.node, ast.ImportFrom):
if ref.node not in already_changed:
assert _rename_name_in_importfrom(sc, ref.node, old_name, new_name)
renames[old_name.rsplit('.', 1)[-1]] = new_name.rsplit('.', 1)[-1]
already_changed.append(ref.node)
has_changed = True
for rename_old, rename_new in six.iteritems(renames):
_rename_reads(sc, t, rename_old, rename_new)
return has_changed
def _rename_name_in_importfrom(sc, node, old_name, new_name):
if old_name == new_name:
return False
module_parts = node.module.split('.')
old_parts = old_name.split('.')
new_parts = new_name.split('.')
# If just the module is changing, rename it
if module_parts[:len(old_parts)] == old_parts:
node.module = '.'.join(new_parts + module_parts[len(old_parts):])
return True
# Find the alias node to be changed
for alias_to_change in node.names:
if alias_to_change.name == old_parts[-1]:
break
else:
return False
alias_to_change.name = new_parts[-1]
# Split the import if the package has changed
if module_parts != new_parts[:-1]:
if len(node.names) > 1:
new_import = import_utils.split_import(sc, node, alias_to_change)
new_import.module = '.'.join(new_parts[:-1])
else:
node.module = '.'.join(new_parts[:-1])
return True
def _rename_reads(sc, t, old_name, new_name):
"""Updates all locations in the module where the given name is read.
Arguments:
sc: (scope.Scope) Scope to work in. This should be the scope of `t`.
t: (ast.AST) The AST to perform updates in.
old_name: (string) Dotted name to update.
new_name: (string) Dotted name to replace it with.
Returns:
True if any changes were made, False otherwise.
"""
name_parts = old_name.split('.')
try:
name = sc.names[name_parts[0]]
for part in name_parts[1:]:
name = name.attrs[part]
except KeyError:
return False
has_changed = False
for ref_node in name.reads:
if isinstance(ref_node, (ast.Name, ast.Attribute)):
ast_utils.replace_child(sc.parent(ref_node), ref_node,
ast.parse(new_name).body[0].value)
has_changed = True
return has_changed
@@ -0,0 +1,119 @@
# coding=utf-8
"""Tests for augment.rename."""
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import unittest
from pasta.augment import rename
from pasta.base import scope
from pasta.base import test_utils
class RenameTest(test_utils.TestCase):
def test_rename_external_in_import(self):
src = 'import aaa.bbb.ccc\naaa.bbb.ccc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy.ccc\nxxx.yyy.ccc.foo()'))
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy\nxxx.yyy.foo()'))
t = ast.parse(src)
self.assertFalse(rename.rename_external(t, 'bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse(src))
def test_rename_external_in_import_with_asname(self):
src = 'import aaa.bbb.ccc as ddd\nddd.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import xxx.yyy.ccc as ddd\nddd.foo()'))
def test_rename_external_in_import_multiple_aliases(self):
src = 'import aaa, aaa.bbb, aaa.bbb.ccc'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('import aaa, xxx.yyy, xxx.yyy.ccc'))
def test_rename_external_in_importfrom(self):
src = 'from aaa.bbb.ccc import ddd\nddd.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx.yyy.ccc import ddd\nddd.foo()'))
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx.yyy import ddd\nddd.foo()'))
t = ast.parse(src)
self.assertFalse(rename.rename_external(t, 'bbb', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse(src))
def test_rename_external_in_importfrom_alias(self):
src = 'from aaa.bbb import ccc\nccc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx import yyy\nyyy.foo()'))
def test_rename_external_in_importfrom_alias_with_asname(self):
src = 'from aaa.bbb import ccc as abc\nabc.foo()'
t = ast.parse(src)
self.assertTrue(rename.rename_external(t, 'aaa.bbb.ccc', 'xxx.yyy'))
self.checkAstsEqual(t, ast.parse('from xxx import yyy as abc\nabc.foo()'))
def test_rename_reads_name(self):
src = 'aaa.bbb()'
t = ast.parse(src)
sc = scope.analyze(t)
self.assertTrue(rename._rename_reads(sc, t, 'aaa', 'xxx'))
self.checkAstsEqual(t, ast.parse('xxx.bbb()'))
def test_rename_reads_name_as_attribute(self):
src = 'aaa.bbb()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse('xxx.yyy.bbb()'))
def test_rename_reads_attribute(self):
src = 'aaa.bbb.ccc()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa.bbb', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse('xxx.yyy.ccc()'))
def test_rename_reads_noop(self):
src = 'aaa.bbb.ccc()'
t = ast.parse(src)
sc = scope.analyze(t)
rename._rename_reads(sc, t, 'aaa.bbb.ccc.ddd', 'xxx.yyy')
rename._rename_reads(sc, t, 'bbb.aaa', 'xxx.yyy')
self.checkAstsEqual(t, ast.parse(src))
def suite():
result = unittest.TestSuite()
result.addTests(unittest.makeSuite(RenameTest))
return result
if __name__ == '__main__':
unittest.main()