# Copyright (C) Internet Systems Consortium, Inc. ("ISC") # # SPDX-License-Identifier: MPL-2.0 # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, you can obtain one at https://mozilla.org/MPL/2.0/. # # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. from pylint.checkers import BaseChecker import astroid class DnsExplicitImportsChecker(BaseChecker): name = "dns-explicit-imports" msgs = { "W9001": ( "Bare 'import dns' is discouraged; import required submodules explicitly", "dns-bare-import", "Emitted when the package root 'dns' is imported directly.", ), "W9002": ( "Missing explicit import for '%s' (add `import %s`)", "dns-missing-submodule-import", "Emitted when code references dns.<...> but the corresponding module prefix " "was not imported with `import dns.<...>`.", ), "W9003": ( "Unused explicit import for '%s' (remove `import %s`)", "dns-unused-submodule-import", "Emitted when a dns.<...> module is imported explicitly but not used.", ), } def __init__(self, linter=None): super().__init__(linter) self._imported = {} self._imported_aliases = set() self._required = {} def visit_module(self, node): # pylint: disable=unused-argument self._imported = {} self._imported_aliases = set() self._required = {} def leave_module(self, node): # pylint: disable=unused-argument for mod, use_node in sorted(self._required.items()): if mod in self._imported: continue prefix = mod + "." if any(name.startswith(prefix) for name in self._imported): continue self.add_message( "dns-missing-submodule-import", node=use_node, args=(mod, mod), ) for mod, import_node in sorted(self._imported.items()): if mod in self._imported_aliases: continue if any( name == mod or name.startswith(mod + ".") for name in self._required ): continue self.add_message( "dns-unused-submodule-import", node=import_node, args=(mod, mod), ) def visit_import(self, node): for name, _asname in node.names: if name == "dns": self.add_message("dns-bare-import", node=node) continue if name.startswith("dns."): self._imported.setdefault(name, node) if _asname: self._imported_aliases.add(name) def visit_importfrom(self, node): # pylint: disable=unused-argument return def visit_attribute(self, node): parent = node.parent # For `dns.a.b.c`, astroid visits intermediate attributes too. # Process only the rightmost node to avoid duplicate bookkeeping. if isinstance(parent, astroid.nodes.Attribute) and parent.expr is node: return mod = self._dns_module_for_attribute(node) if mod is None: return self._required.setdefault(mod, node) @staticmethod def _dns_attribute_nodes(node): """ Return the chain of Attribute nodes as a list. For `dns.a.b.c`, return the list of Attribute nodes for `dns.a`, `dns.a.b`, and `dns.a.b.c`. Return None if the chain is not rooted in `dns`. """ if not isinstance(node, astroid.nodes.Attribute): return None nodes = [] expr = node while isinstance(expr, astroid.nodes.Attribute): nodes.append(expr) expr = expr.expr if not isinstance(expr, astroid.nodes.Name) or expr.name != "dns": return None return list(reversed(nodes)) @classmethod def _dns_module_for_attribute(cls, node): """ For dns.a.b.c, return the longest dns.a.b... prefix that is likely to be a module, or None if the chain is not rooted in dns. """ last_module = None chain_nodes = cls._dns_attribute_nodes(node) if chain_nodes is None: return None full = "dns." + ".".join(part.attrname for part in chain_nodes) # Prefer inferred module names to avoid treating classes/constants as # modules (e.g. `dns.name.NameRelation` should resolve to `dns.name`). for chain_node in chain_nodes: inferred = cls._infer_module_name(chain_node) if inferred is not None and full.startswith(inferred): last_module = inferred if last_module is not None: return last_module # Fallback when inference is unavailable: assume the terminal segment # is not a module symbol and require the parent path. parts = full.split(".") if len(parts) <= 2: return full return ".".join(parts[:-1]) @staticmethod def _infer_module_name(node): """Infer `dns.` for a node; return None if inference is unsure.""" try: for inferred in node.infer(): if inferred is astroid.util.Uninferable: continue # Inference can return either a Module node directly or another # symbol rooted in a module; normalize both to module name. module = ( inferred if isinstance(inferred, astroid.nodes.Module) else inferred.root() ) name = module.name if name.startswith("dns."): return name # Inference can fail for dynamic/partial code; fall back gracefully. except astroid.AstroidError: pass return None def register(linter): linter.register_checker(DnsExplicitImportsChecker(linter))