# Copyright 2021 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import difflib import os import re import subprocess import textwrap from argparse import ArgumentParser, Namespace from datetime import date from pathlib import Path from typing import Any, Callable, Optional, Union from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES from ..utils import is_libcst_available from . import BaseTransformersCLICommand from .add_fast_image_processor import add_fast_image_processor # We protect this import to avoid requiring it for all `transformers` CLI commands - however it is actually # strictly required for this one (we need it both for modular and for the following Visitor) if is_libcst_available(): import libcst as cst from libcst import CSTVisitor from libcst import matchers as m class ClassFinder(CSTVisitor): """ A visitor to find all classes in a python module. """ def __init__(self): self.classes: list = [] self.public_classes: list = [] self.is_in_class = False def visit_ClassDef(self, node: cst.ClassDef) -> None: """Record class names. We assume classes always only appear at top-level (i.e. no class definition in function or similar)""" self.classes.append(node.name.value) self.is_in_class = True def leave_ClassDef(self, node: cst.ClassDef): self.is_in_class = False def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine): """Record all public classes inside the `__all__` assignment.""" simple_top_level_assign_structure = m.SimpleStatementLine( body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] ) if not self.is_in_class and m.matches(node, simple_top_level_assign_structure): assigned_variable = node.body[0].targets[0].target.value if assigned_variable == "__all__": elements = node.body[0].value.elements self.public_classes = [element.value.value for element in elements] CURRENT_YEAR = date.today().year TRANSFORMERS_PATH = Path(__file__).parent.parent REPO_PATH = TRANSFORMERS_PATH.parent.parent COPYRIGHT = f""" # coding=utf-8 # Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """.lstrip() class ModelInfos: """ Retrieve the basic information about an existing model classes. """ def __init__(self, lowercase_name: str): # Just to make sure it's indeed lowercase self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_") if self.lowercase_name not in CONFIG_MAPPING_NAMES: self.lowercase_name.replace("_", "-") if self.lowercase_name not in CONFIG_MAPPING_NAMES: raise ValueError(f"{lowercase_name} is not a valid model name") self.paper_name = MODEL_NAMES_MAPPING[self.lowercase_name] self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name] self.camelcase_name = self.config_class.replace("Config", "") # Get tokenizer class if self.lowercase_name in TOKENIZER_MAPPING_NAMES: self.tokenizer_class, self.fast_tokenizer_class = TOKENIZER_MAPPING_NAMES[self.lowercase_name] self.fast_tokenizer_class = ( None if self.fast_tokenizer_class == "PreTrainedTokenizerFast" else self.fast_tokenizer_class ) else: self.tokenizer_class, self.fast_tokenizer_class = None, None self.image_processor_class, self.fast_image_processor_class = IMAGE_PROCESSOR_MAPPING_NAMES.get( self.lowercase_name, (None, None) ) self.video_processor_class = VIDEO_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None) self.feature_extractor_class = FEATURE_EXTRACTOR_MAPPING_NAMES.get(self.lowercase_name, None) self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None) def add_content_to_file(file_name: Union[str, os.PathLike], new_content: str, add_after: str): """ A utility to add some content inside a given file. Args: file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content. new_content (`str`): The content to add. add_after (`str`): The new content is added just after the first instance matching it. """ with open(file_name, "r", encoding="utf-8") as f: old_content = f.read() before, after = old_content.split(add_after, 1) new_content = before + add_after + new_content + after with open(file_name, "w", encoding="utf-8") as f: f.write(new_content) def add_model_to_auto_mappings( old_model_infos: ModelInfos, new_lowercase_name: str, new_model_paper_name: str, filenames_to_add: list[tuple[str, bool]], ): """ Add a model to all the relevant mappings in the auto module. Args: old_model_infos (`ModelInfos`): The structure containing the class information of the old model. new_lowercase_name (`str`): The new lowercase model name. new_model_paper_name (`str`): The fully cased name (as in the official paper name) of the new model. filenames_to_add (`list[tuple[str, bool]]`): A list of tuples of all potential filenames to add for a new model, along a boolean flag describing if we should add this file or not. For example, [(`modeling_xxx.px`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...] """ new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_")) old_lowercase_name = old_model_infos.lowercase_name old_cased_name = old_model_infos.camelcase_name filenames_to_add = [ (filename.replace(old_lowercase_name, "auto"), to_add) for filename, to_add in filenames_to_add[1:] ] # fast tokenizer/image processor have the same auto mappings as normal ones corrected_filenames_to_add = [] for file, to_add in filenames_to_add: if re.search(r"(?:tokenization)|(?:image_processing)_auto_fast.py", file): previous_file, previous_to_add = corrected_filenames_to_add[-1] corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add) else: corrected_filenames_to_add.append((file, to_add)) # Add the config mappings directly as the handling for config is a bit different add_content_to_file( TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py", new_content=f' ("{new_lowercase_name}", "{new_cased_name}Config"),\n', add_after="CONFIG_MAPPING_NAMES = OrderedDict[str, str](\n [\n # Add configs here\n", ) add_content_to_file( TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py", new_content=f' ("{new_lowercase_name}", "{new_model_paper_name}"),\n', add_after="MODEL_NAMES_MAPPING = OrderedDict[str, str](\n [\n # Add full (and cased) model names here\n", ) for filename, to_add in corrected_filenames_to_add: if to_add: # The auto mapping filename = filename.replace("_fast.py", ".py") with open(TRANSFORMERS_PATH / "models" / "auto" / filename) as f: file = f.read() # The regex has to be a bit complex like this as the tokenizer mapping has new lines everywhere matching_lines = re.findall( rf'( {{8,12}}\(\s*"{old_lowercase_name}",.*?\),\n)(?: {{4,12}}\(|\])', file, re.DOTALL ) for match in matching_lines: add_content_to_file( TRANSFORMERS_PATH / "models" / "auto" / filename, new_content=match.replace(old_lowercase_name, new_lowercase_name).replace( old_cased_name, new_cased_name ), add_after=match, ) def create_doc_file(new_paper_name: str, public_classes: list[str]): """ Create a new doc file to fill for the new model. Args: new_paper_name (`str`): The fully cased name (as in the official paper name) of the new model. public_classes (`list[str]`): A list of all the public classes that the model will have in the library. """ added_note = ( "\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that " "may not be rendered properly in your Markdown viewer.\n\n-->\n\n" ) copyright_for_markdown = re.sub(r"# ?", "", COPYRIGHT).replace("coding=utf-8\n", "