| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import collections |
| | import importlib.util |
| | import os |
| | import re |
| | from pathlib import Path |
| |
|
| |
|
| | PATH_TO_TRANSFORMERS = "src/transformers" |
| |
|
| |
|
| | |
| | _re_backend = re.compile(r"is\_([a-z_]*)_available()") |
| | |
| | _re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}") |
| | |
| | _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') |
| | |
| | _re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)") |
| | |
| | _re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') |
| | |
| | _re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]") |
| | |
| | _re_quote_object = re.compile('^\s+"([^"]+)",') |
| | |
| | _re_between_brackets = re.compile("^\s+\[([^\]]+)\]") |
| | |
| | _re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") |
| | |
| | _re_try = re.compile(r"^\s*try:") |
| | |
| | _re_else = re.compile(r"^\s*else:") |
| |
|
| |
|
def find_backend(line):
    """
    Return the backend(s) tested by an `if not is_xxx_available()` line of the init.

    Lines that do not match that test yield `None`. When several backends are mentioned on the line, their names
    are sorted and joined with `"_and_"`.
    """
    if _re_test_backend.search(line) is None:
        return None
    # `_re_backend` has two groups, so every hit is a tuple; the backend name is its first item.
    names = sorted(hit[0] for hit in _re_backend.findall(line))
    return "_and_".join(names)
| |
|
| |
|
def parse_init(init_file):
    """
    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
    defined.

    Args:
        init_file: Path of the init file to parse.

    Returns:
        `None` when no line of the file starts with `_import_structure = {`. Otherwise a tuple
        `(import_dict_objects, type_hint_objects)` of two dicts mapping a backend name (`"none"` for the objects that
        are not guarded by a backend test) to the list of object names collected in that half of the init.

    If the file ends before a section the parser is scanning for (for instance there is no `if TYPE_CHECKING` line),
    scanning stops at the end of the file and whatever was collected so far is returned, rather than failing with an
    `IndexError`.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    num_lines = len(lines)

    line_index = 0
    while line_index < num_lines and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # No `_import_structure = {` line: nothing to compare in this init.
    if line_index >= num_lines:
        return None

    # First grab the objects that are not tied to a backend in _import_structure.
    objects = []
    while (
        line_index < num_lines
        and not lines[line_index].startswith("if TYPE_CHECKING")
        and find_backend(lines[line_index]) is None
    ):
        line = lines[line_index]
        # The whole dict sits on a single line: collect every bracketed list in it.
        one_line_match = _re_one_line_import_struct.search(line)
        if one_line_match:
            content = one_line_match.groups()[0]
            for imp in re.findall(r"\[([^\]]+)\]", content):
                # [1:-1] strips the surrounding quotes of each name.
                objects.extend([obj[1:-1] for obj in imp.split(", ")])
            line_index += 1
            continue
        key_value_match = _re_import_struct_key_value.search(line)
        if key_value_match is not None:
            imports = [obj[1:-1] for obj in key_value_match.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            # Drops the 8 spaces + opening quote in front and the closing quote, comma and newline at the end.
            objects.append(line[9:-3])
        line_index += 1

    import_dict_objects = {"none": objects}
    # Then the backend-specific objects of _import_structure.
    while line_index < num_lines and not lines[line_index].startswith("if TYPE_CHECKING"):
        backend = find_backend(lines[line_index])
        # A backend test only counts when the previous line opens a try block.
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll to the `else:` line that follows the backend test.
            while line_index < num_lines and _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Collect until the code unindents (blank lines are part of the section).
            while line_index < num_lines and (
                len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4)
            ):
                line = lines[line_index]
                add_one_match = _re_import_struct_add_one.search(line)
                add_many_match = _re_import_struct_add_many.search(line)
                brackets_match = _re_between_brackets.search(line)
                quote_match = _re_quote_object.search(line)
                if add_one_match is not None:
                    objects.append(add_one_match.groups()[0])
                elif add_many_match is not None:
                    imports = add_many_match.groups()[0].split(", ")
                    objects.extend([obj[1:-1] for obj in imports if len(obj) > 0])
                elif brackets_match is not None:
                    imports = brackets_match.groups()[0].split(", ")
                    objects.extend([obj[1:-1] for obj in imports if len(obj) > 0])
                elif quote_match is not None:
                    objects.append(quote_match.groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # From here on we are in the TYPE_CHECKING half: first the objects that are not tied to a backend.
    objects = []
    while (
        line_index < num_lines
        and find_backend(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        import_match = _re_import.search(line)
        if import_match is not None:
            objects.extend(import_match.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            # Drops the 8 spaces in front and the trailing comma and newline.
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}
    # Then the backend-specific objects of the TYPE_CHECKING half.
    while line_index < num_lines:
        backend = find_backend(lines[line_index])
        # A backend test only counts when the previous line opens a try block.
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll to the `else:` line that follows the backend test.
            while line_index < num_lines and _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Collect until the code unindents (blank lines are part of the section).
            while line_index < num_lines and (
                len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)
            ):
                line = lines[line_index]
                import_match = _re_import.search(line)
                if import_match is not None:
                    objects.extend(import_match.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects
| |
|
| |
|
def analyze_results(import_dict_objects, type_hint_objects):
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.

    Args:
        import_dict_objects: Dict mapping a backend name to the list of objects found in `_import_structure`.
        type_hint_objects: Dict mapping a backend name to the list of objects found under `TYPE_CHECKING`.

    Returns:
        A list of error messages, empty when both halves agree. When the two dicts do not have the same keys in the
        same order, a single message about mismatched backends is returned and nothing else is checked.
    """

    def find_duplicates(seq):
        return [k for k, v in collections.Counter(seq).items() if v > 1]

    # Order-sensitive on purpose: the key lists are compared, not the key sets.
    if list(import_dict_objects) != list(type_hint_objects):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key, import_names in import_dict_objects.items():
        type_hint_names = type_hint_objects[key]

        duplicate_imports = find_duplicates(import_names)
        if duplicate_imports:
            errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
        duplicate_type_hints = find_duplicates(type_hint_names)
        if duplicate_type_hints:
            errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")

        # Build the sets once so each membership test below is O(1) instead of scanning a list.
        import_set = set(import_names)
        type_hint_set = set(type_hint_names)
        if import_set != type_hint_set:
            name = "base imports" if key == "none" else f"{key} backend"
            errors.append(f"Differences for {name}:")
            # Iterate the lists (not the sets) to keep the original order and one line per occurrence.
            for a in type_hint_names:
                if a not in import_set:
                    errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
            for a in import_names:
                if a not in type_hint_set:
                    errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
    return errors
| |
|
| |
|
def check_all_inits():
    """
    Walk `PATH_TO_TRANSFORMERS` and compare both halves of every `__init__.py` found there.

    Raises:
        ValueError: If at least one init does not define the same objects in both halves. The message contains one
            paragraph per failing init.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        fname = os.path.join(root, "__init__.py")
        parsed = parse_init(fname)
        # `parse_init` returns None for inits that have no `_import_structure` to check.
        if parsed is None:
            continue
        import_dict_objects, type_hint_objects = parsed
        errors = analyze_results(import_dict_objects, type_hint_objects)
        if errors:
            # Prefix the first message with the name of the offending file.
            header = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
            failures.append("\n".join([header] + errors[1:]))
    if failures:
        raise ValueError("\n\n".join(failures))
| |
|
| |
|
def get_transformers_submodules(root_dir=None):
    """
    Returns the list of Transformers submodules.

    Args:
        root_dir: Directory to scan. Defaults to `PATH_TO_TRANSFORMERS` when left to `None`.

    Returns:
        A list of dotted names relative to `root_dir`: one entry per folder that contains at least one `*.py` file
        (folders whose name starts with `_` are skipped and not explored), plus one entry per file other than
        `__init__.py` that sits directly in `root_dir` and whose name has no dot once `.py` is removed.
    """
    if root_dir is None:
        root_dir = PATH_TO_TRANSFORMERS

    submodules = []
    for path, directories, files in os.walk(root_dir):
        # Prune private folders with a slice assignment: `os.walk` then does not descend into them, and the list is
        # not mutated while it is being iterated (removing items inside the loop skipped the folder that followed).
        directories[:] = [folder for folder in directories if not folder.startswith("_")]
        for folder in directories:
            folder_path = Path(path) / folder
            # Ignore folders that do not directly contain any Python file.
            if not any(folder_path.glob("*.py")):
                continue
            short_path = str(folder_path.relative_to(root_dir))
            submodules.append(short_path.replace(os.path.sep, "."))
        for fname in files:
            if fname == "__init__.py":
                continue
            short_path = str((Path(path) / fname).relative_to(root_dir))
            submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
            # A name without any dot can only come from a file sitting directly in `root_dir`.
            if len(submodule.split(".")) == 1:
                submodules.append(submodule)
    return submodules
| |
|
| |
|
# Submodules that `check_submodules` does not require to be registered in the main init.
IGNORE_SUBMODULES = [
    "convert_pytorch_checkpoint_to_tf2",
    "modeling_flax_pytorch_utils",
]


def check_submodules():
    """
    Check that every submodule found on disk is a key of `_import_structure` in the main init.

    The init at `PATH_TO_TRANSFORMERS/__init__.py` is loaded directly from its file under the name `transformers`,
    then compared against `get_transformers_submodules()`, ignoring the names in `IGNORE_SUBMODULES`.

    Raises:
        ValueError: If at least one submodule is not registered. The message lists the missing submodules.
    """
    import sys  # Local import: only needed here to register the module loaded from file.

    spec = importlib.util.spec_from_file_location(
        "transformers",
        os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
        submodule_search_locations=[PATH_TO_TRANSFORMERS],
    )
    # `Loader.load_module()` is deprecated; build the module from the spec and execute it instead. The module is
    # registered in `sys.modules` before execution so that imports made by the init can resolve the package.
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    # Use whatever ends up registered under that name, in case the init replaced its own `sys.modules` entry.
    transformers = sys.modules.get(spec.name, module)

    module_not_registered = [
        module_name
        for module_name in get_transformers_submodules()
        if module_name not in IGNORE_SUBMODULES and module_name not in transformers._import_structure
    ]
    if module_not_registered:
        list_of_modules = "\n".join(f"- {module_name}" for module_name in module_not_registered)
        raise ValueError(
            "The following submodules are not properly registered in the main init of Transformers:\n"
            f"{list_of_modules}\n"
            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
        )
| |
|
| |
|
# Script entry point: run both checks in turn. Each one raises a ValueError when it finds a problem.
if __name__ == "__main__":
    check_all_inits()
    check_submodules()
| |
|