# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from argparse import ArgumentParser, Namespace
from datetime import date
from pathlib import Path

from ..utils import logging
from . import BaseTransformersCLICommand


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent


def _utils_import_regex(module_pattern: str) -> "re.Pattern":
    """
    Build a regex matching `from <module> import a, b` or `from <module> import (...)`.

    Group 1 holds a single-line import list, group 2 the body of a parenthesized import list.

    Args:
        module_pattern: Regex-escaped module path, e.g. `r"\\.\\.\\.utils"`.
    """
    pattern = rf"""
        from\s+{module_pattern}\s+import\s+
        (?:                     # Non-capturing group for either:
            ([\w, ]+)           # 1. Single-line imports (e.g., 'a, b')
            |                   # OR
            \((.*?)\)           # 2. Multi-line imports (e.g., '(a, ... b)')
        )
    """
    return re.compile(pattern, re.VERBOSE | re.DOTALL)


def _add_torchvision_to_import_names(raw_imports: str) -> list:
    """
    Split a raw comma-separated import list into names and make sure `is_torchvision_available` is among them.

    Empty pieces (e.g. produced by a trailing comma before the closing parenthesis) are dropped.
    The list is sorted only when `is_torchvision_available` had to be added.
    """
    names = [name.strip() for name in raw_imports.split(",")]
    names = [name for name in names if name]
    if "is_torchvision_available" not in names:
        names.append("is_torchvision_available")
        names.sort()
    return names


def _add_to_type_checking_block(content: str, fast_module_name: str) -> str:
    """
    Add `from .<fast_module_name> import *` to the `if TYPE_CHECKING:` block of a new-style model init.

    Raises:
        ValueError: If no `if TYPE_CHECKING:` block followed by `else:` is found.
    """
    # The captured group is read back by name ("if_block") below.
    block_regex = re.compile(r"if TYPE_CHECKING:\n(?P<if_block>.*?)(?=\s*else:)", re.DOTALL)
    match = block_regex.search(content)
    if not match:
        raise ValueError("Couldn't find the 'if TYPE_CHECKING' block.")

    entries = match.group("if_block").split("\n")
    # Reuse the indentation of the first existing line of the block.
    indent = " " * (len(entries[0]) - len(entries[0].lstrip()))
    new_entry = f"{indent}from .{fast_module_name} import *"
    if new_entry not in entries:
        entries.append(new_entry)
    entries.sort()
    updated_block = "\n".join(entries)
    return content[: match.start("if_block")] + updated_block + content[match.end("if_block") :]


def _add_to_import_structure(content: str, fast_module_name: str, fast_image_processor_name: str) -> str:
    """
    Add torchvision-guarded entries for the fast image processor to an old-style `_import_structure` init.

    The new entries are inserted right after the slow image processor's `_import_structure` entry and
    right after its `TYPE_CHECKING` import.

    Raises:
        ValueError: If the slow image processor's entry or import statement cannot be found.
    """
    slow_module_name = fast_module_name[: -len("_fast")]
    slow_image_processor_name = fast_image_processor_name[: -len("Fast")]

    # Ensure `is_torchvision_available` is imported from `...utils`; the import is always rewritten
    # in the parenthesized multi-line form.
    def replacement_function(match: "re.Match") -> str:
        imports = _add_torchvision_to_import_names(match.group(1) or match.group(2))
        updated_imports = "(\n    " + ",\n    ".join(imports) + ",\n)"
        return f"from ...utils import {updated_imports}"

    updated_content = _utils_import_regex(r"\.\.\.utils").sub(replacement_function, content)

    vision_import_structure_block = (
        f'    _import_structure["{slow_module_name}"] = ["{slow_image_processor_name}"]\n'
    )
    added_import_structure_block = (
        "try:\n    if not is_torchvision_available():\n"
        "        raise OptionalDependencyNotAvailable()\n"
        "except OptionalDependencyNotAvailable:\n"
        "    pass\n"
        "else:\n"
        f'    _import_structure["{fast_module_name}"] = ["{fast_image_processor_name}"]\n'
    )
    if vision_import_structure_block not in updated_content:
        raise ValueError("Couldn't find the 'vision _import_structure block' block.")
    if added_import_structure_block not in updated_content:
        updated_content = updated_content.replace(
            vision_import_structure_block, vision_import_structure_block + "\n" + added_import_structure_block
        )

    vision_import_statement_block = f"        from .{slow_module_name} import {slow_image_processor_name}\n"
    added_import_statement_block = (
        "    try:\n        if not is_torchvision_available():\n"
        "            raise OptionalDependencyNotAvailable()\n"
        "    except OptionalDependencyNotAvailable:\n"
        "        pass\n"
        "    else:\n"
        f"        from .{fast_module_name} import {fast_image_processor_name}\n"
    )
    if vision_import_statement_block not in updated_content:
        raise ValueError("Couldn't find the 'vision import statement' block.")
    if added_import_statement_block not in updated_content:
        updated_content = updated_content.replace(
            vision_import_statement_block, vision_import_statement_block + "\n" + added_import_statement_block
        )
    return updated_content


def add_fast_image_processor_to_model_init(
    fast_image_processing_module_file: str, fast_image_processor_name: str, model_name: str
):
    """
    Add the fast image processor to the __init__.py file of the model.

    Two init layouts are handled:
    - the newer layout (detected by the presence of `import *`), where a `from .<module> import *`
      line is added to the `if TYPE_CHECKING:` block;
    - the older `_import_structure` layout, where torchvision-guarded entries are added next to the
      slow image processor's entries.

    Args:
        fast_image_processing_module_file: Path of the fast image processing module; only its file stem is used.
        fast_image_processor_name: Class name of the fast image processor (the slow name + "Fast").
        model_name: Name of the model folder under `transformers/models`.

    Raises:
        ValueError: If the expected anchor blocks cannot be found in the init file.
    """
    init_file = TRANSFORMERS_PATH / "models" / model_name / "__init__.py"
    with open(init_file, "r", encoding="utf-8") as f:
        content = f.read()

    # e.g. ".../image_processing_vit_fast.py" -> "image_processing_vit_fast"
    fast_module_name = Path(fast_image_processing_module_file).stem

    if "import *" in content:
        updated_content = _add_to_type_checking_block(content, fast_module_name)
    else:
        updated_content = _add_to_import_structure(content, fast_module_name, fast_image_processor_name)

    with open(init_file, "w", encoding="utf-8") as f:
        f.write(updated_content)


def add_fast_image_processor_to_auto(image_processor_name: str, fast_image_processor_name: str):
    """
    Add the fast image processor to the auto module.

    Every `("<image_processor_name>",)` tuple in `models/auto/image_processing_auto.py` is rewritten to
    `("<image_processor_name>", "<fast_image_processor_name>")`.
    """
    auto_file = TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py"
    with open(auto_file, "r", encoding="utf-8") as f:
        content = f.read()

    updated_content = content.replace(
        f'("{image_processor_name}",)', f'("{image_processor_name}", "{fast_image_processor_name}")'
    )

    with open(auto_file, "w", encoding="utf-8") as f:
        f.write(updated_content)


def add_fast_image_processor_to_doc(fast_image_processor_name: str, model_name: str):
    """
    Add the fast image processor to the model's doc file(s).

    Doc files are looked up as `docs/source/*/model_doc/<model_name>.md`, falling back to the name with
    "_" replaced by "-". The fast autodoc section is inserted right after the slow one.

    Raises:
        ValueError: If no doc file is found.
    """
    doc_source = REPO_PATH / "docs" / "source"
    doc_files = list(doc_source.glob(f"*/model_doc/{model_name}.md"))
    if not doc_files:
        # try again with "-"
        doc_files = list(doc_source.glob(f"*/model_doc/{model_name.replace('_', '-')}.md"))
    if not doc_files:
        raise ValueError(f"No doc files found for {model_name}")

    slow_image_processor_name = fast_image_processor_name[: -len("Fast")]
    base_doc_string = f"## {slow_image_processor_name}\n\n[[autodoc]] {slow_image_processor_name}\n    - preprocess"
    fast_doc_string = f"## {fast_image_processor_name}\n\n[[autodoc]] {fast_image_processor_name}\n    - preprocess"

    for doc_file in doc_files:
        with open(doc_file, "r", encoding="utf-8") as f:
            content = f.read()

        if fast_doc_string not in content:
            updated_content = content.replace(base_doc_string, base_doc_string + "\n\n" + fast_doc_string)
            with open(doc_file, "w", encoding="utf-8") as f:
                f.write(updated_content)


def add_fast_image_processor_to_tests(fast_image_processor_name: str, model_name: str):
    """
    Add the fast image processor to the model's image processing tests.

    Imports `is_torchvision_available`, adds a guarded import of the fast class after the slow one,
    and adds a `fast_image_processing_class` attribute after `image_processing_class`. Missing test
    files or a missing `image_processing_class` line only log a warning; in that case nothing is written.
    """
    tests_path = REPO_PATH / "tests" / "models" / model_name
    test_file = tests_path / f"test_image_processing_{model_name}.py"
    if not os.path.exists(test_file):
        logger.warning(f"No test file found for {model_name}. Skipping.")
        return

    with open(test_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Add `is_torchvision_available` to the `transformers.utils` imports, keeping the single-line or
    # multi-line form of each matched import.
    def replacement_function(match: "re.Match") -> str:
        existing_imports = _add_torchvision_to_import_names(match.group(1) or match.group(2))
        if match.group(1):
            updated_imports = ", ".join(existing_imports)
        else:
            updated_imports = "(\n    " + ",\n    ".join(existing_imports) + ",\n)"
        return f"from transformers.utils import {updated_imports}"

    updated_content = _utils_import_regex(r"transformers\.utils").sub(replacement_function, content)

    slow_image_processor_name = fast_image_processor_name[: -len("Fast")]
    base_import_string = f"    from transformers import {slow_image_processor_name}"
    fast_import_string = (
        f"    if is_torchvision_available():\n        from transformers import {fast_image_processor_name}"
    )
    if fast_import_string not in updated_content:
        updated_content = updated_content.replace(base_import_string, base_import_string + "\n\n" + fast_import_string)

    image_processing_class_line = re.search(r"    image_processing_class = .*", updated_content)
    if not image_processing_class_line:
        logger.warning(f"Couldn't find the 'image_processing_class' line in {test_file}. Skipping.")
        return

    fast_image_processing_class_line = (
        f"    fast_image_processing_class = {fast_image_processor_name} if is_torchvision_available() else None"
    )
    if "    fast_image_processing_class = " not in updated_content:
        updated_content = updated_content.replace(
            image_processing_class_line.group(0),
            image_processing_class_line.group(0) + "\n" + fast_image_processing_class_line,
        )

    with open(test_file, "w", encoding="utf-8") as f:
        f.write(updated_content)


def get_fast_image_processing_content_header(content: str) -> str:
    """
    Get the header of the slow image processor file, adapted for the fast file.

    The leading `# coding=utf-8` comment block is reused with the copyright year set to the current
    year; a default Apache-2.0 header is returned (with a warning) when it is missing. A single-line
    `\"\"\"Image processor...\"\"\"` module docstring is appended as `\"\"\"Fast Image processor...\"\"\"`.
    """
    content_header = re.search(r"^# coding=utf-8\n(#[^\n]*\n)*", content, re.MULTILINE)
    if not content_header:
        logger.warning("Couldn't find the content header in the slow image processor file. Using a default header.")
        return (
            "# coding=utf-8\n"
            f"# Copyright {CURRENT_YEAR} The HuggingFace Team. All rights reserved.\n"
            "#\n"
            '# Licensed under the Apache License, Version 2.0 (the "License");\n'
            "# you may not use this file except in compliance with the License.\n"
            "# You may obtain a copy of the License at\n"
            "#\n"
            "#     http://www.apache.org/licenses/LICENSE-2.0\n"
            "#\n"
            "# Unless required by applicable law or agreed to in writing, software\n"
            '# distributed under the License is distributed on an "AS IS" BASIS,\n'
            "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
            "# See the License for the specific language governing permissions and\n"
            "# limitations under the License.\n"
            "\n"
        )

    content_header = content_header.group(0)
    content_header = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content_header)

    match = re.search(r'^"""Image processor.*$', content, re.MULTILINE)
    # Only copy the docstring when it closes on the same line; copying the first line of a
    # multi-line docstring would leave an unterminated string in the generated file.
    if match and match.group(0).rstrip().endswith('"""'):
        content_header += match.group(0).replace("Image processor", "Fast Image processor")

    return content_header


def write_default_fast_image_processor_file(
    fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
):
    """
    Write a default fast image processor file.

    Used when encountering a problem while parsing the slow image processor file: every class
    attribute is set to None and must be filled in manually.
    """
    imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n\n\n"
    content_header = get_fast_image_processing_content_header(content_base_file)
    class_content = (
        f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
        "    # To be implemented\n"
        "    resample = None\n"
        "    image_mean = None\n"
        "    image_std = None\n"
        "    size = None\n"
        "    default_to_square = None\n"
        "    crop_size = None\n"
        "    do_resize = None\n"
        "    do_center_crop = None\n"
        "    do_rescale = None\n"
        "    do_normalize = None\n"
        "    do_convert_rgb = None\n\n\n"
        f'__all__ = ["{fast_image_processor_name}"]\n'
    )

    with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
        f.write(content_header + imports + class_content)


def _extract_default_args(init: str) -> dict:
    """
    Best-effort extraction of default argument values from the source of a slow `__init__`.

    Returns raw source snippets (strings) or None, keyed by argument name.
    """
    # The signature is taken as everything before the first ")"; defaults containing ")" are not supported.
    init_signature_block = init.split(")")[0]
    # Each ":"-separated chunk ends (on its last line) with the name of the next annotated argument.
    arg_names = [chunk.split("\n")[-1].strip() for chunk in init_signature_block.split(":")]
    default_args = re.findall(r"= (.*?)(?:,|\))", init_signature_block)
    # NOTE(review): names and defaults are paired positionally, which presumably assumes every
    # annotated argument has a default.
    default_args_dict = dict(zip(arg_names, default_args))

    # Some defaults are only resolved in the __init__ body, e.g. `size = size if size is not None else {...}`.
    body_default_patterns = {
        "size": r"size = size if size is not None else\s+(.*)",
        "crop_size": r"crop_size = crop_size if crop_size is not None else\s+(.*)",
        "image_mean": r"self.image_mean = image_mean if image_mean is not None else\s+(.*)",
        "image_std": r"self.image_std = image_std if image_std is not None else\s+(.*)",
    }
    for key, pattern in body_default_patterns.items():
        found = re.findall(pattern, init)
        default_args_dict[key] = found[0] if found else None

    default_args_dict["default_to_square"] = False if "(size, default_to_square=False" in init else None
    return default_args_dict


def add_fast_image_processor_file(
    fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
):
    """
    Add the fast image processor file to the model's folder.

    Default class attributes are extracted from the slow class's `__init__` in `content_base_file`.
    If the slow class or its `__init__` cannot be found, a default file is written instead. An existing
    file is left untouched.
    """
    if os.path.exists(fast_image_processing_module_file):
        print(f"{fast_image_processing_module_file} already exists. Skipping.")
        return

    slow_image_processor_name = fast_image_processor_name[: -len("Fast")]
    # The class body runs until the next non-indented line (or the end of the file).
    match = re.search(rf"class {slow_image_processor_name}.*?(\n\S|$)", content_base_file, re.DOTALL)
    if not match:
        print(f"Couldn't find the {slow_image_processor_name} class in {fast_image_processing_module_file}")
        print("Creating a new file with the default content.")
        return write_default_fast_image_processor_file(
            fast_image_processing_module_file, fast_image_processor_name, content_base_file
        )
    slow_class_content = match.group(0).rstrip()

    # The __init__ block starts with `def __init__` and ends at the next `def `.
    match = re.search(r"def __init__.*?def ", slow_class_content, re.DOTALL)
    if not match:
        print(
            f"Couldn't find the __init__ block for {slow_image_processor_name} in {fast_image_processing_module_file}"
        )
        print("Creating a new file with the default content.")
        return write_default_fast_image_processor_file(
            fast_image_processing_module_file, fast_image_processor_name, content_base_file
        )
    default_args_dict = _extract_default_args(match.group(0))

    content_header = get_fast_image_processing_content_header(content_base_file)
    class_content = (
        "@auto_docstring\n"
        f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
        "    # This generated class can be used as a starting point for the fast image processor.\n"
        "    # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
        "    # only the default values should be set in the class.\n"
        "    # If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.\n"
        "    # In most cases, only the `_preprocess` method should be overridden.\n\n"
        "    # For an example of a fast image processor requiring more complex augmentations, see `LlavaNextImageProcessorFast`.\n\n"
        "    # Default values should be checked against the slow image processor\n"
        "    # None values left after checking can be removed\n"
        f"    resample = {default_args_dict.get('resample')}\n"
        f"    image_mean = {default_args_dict.get('image_mean')}\n"
        f"    image_std = {default_args_dict.get('image_std')}\n"
        f"    size = {default_args_dict.get('size')}\n"
        f"    default_to_square = {default_args_dict.get('default_to_square')}\n"
        f"    crop_size = {default_args_dict.get('crop_size')}\n"
        f"    do_resize = {default_args_dict.get('do_resize')}\n"
        f"    do_center_crop = {default_args_dict.get('do_center_crop')}\n"
        f"    do_rescale = {default_args_dict.get('do_rescale')}\n"
        f"    do_normalize = {default_args_dict.get('do_normalize')}\n"
        f"    do_convert_rgb = {default_args_dict.get('do_convert_rgb')}\n\n\n"
        f'__all__ = ["{fast_image_processor_name}"]\n'
    )

    imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n"
    image_utils_imports = []
    resample = default_args_dict.get("resample")
    if resample is not None and "PILImageResampling" in resample:
        image_utils_imports.append("PILImageResampling")
    for key in ("image_mean", "image_std"):
        value = default_args_dict.get(key)
        # A default without digits is treated as a named constant to import from `...image_utils`.
        if value is not None and not any(char.isdigit() for char in value):
            image_utils_imports.append(value)
    if image_utils_imports:
        image_utils_imports.sort()
        imports += f"from ...image_utils import {', '.join(image_utils_imports)}\n"
    imports += "from ...utils import auto_docstring\n"

    with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
        f.write(content_header + imports + "\n\n" + class_content)


def add_fast_image_processor(model_name: str):
    """
    Add the necessary references to the fast image processor in the transformers package,
    and create the fast image processor file in the model's folder.

    Raises:
        ValueError: If no slow image processing module, or not exactly one `...ImageProcessor` class, is found.
    """
    model_module = TRANSFORMERS_PATH / "models" / model_name
    # Only slow modules are candidates; an existing `_fast.py` module must not be used as the base.
    slow_module_files = sorted(
        path for path in model_module.glob("image_processing*.py") if not path.name.endswith("_fast.py")
    )
    if not slow_module_files:
        raise ValueError(f"No image processing module found in {model_module}")
    image_processing_module_path = slow_module_files[0]
    image_processing_module_file = str(image_processing_module_path)

    with open(image_processing_module_file, "r", encoding="utf-8") as f:
        content_base_file = f.read()

    # Top-level classes whose name ends exactly with "ImageProcessor"; `\b` stops the match from
    # backtracking into names such as `FooImageProcessorKwargs`.
    image_processor_name = re.findall(r"^class (\w*ImageProcessor)\b", content_base_file, re.MULTILINE)
    if not image_processor_name:
        raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}")
    elif len(image_processor_name) > 1:
        raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}")

    image_processor_name = image_processor_name[0]
    fast_image_processor_name = image_processor_name + "Fast"
    # Only the file name is changed, so directories containing ".py" in the path are left intact.
    fast_image_processing_module_file = str(
        image_processing_module_path.with_name(image_processing_module_path.stem + "_fast.py")
    )

    print(f"Adding {fast_image_processor_name} to {fast_image_processing_module_file}")

    add_fast_image_processor_to_model_init(
        fast_image_processing_module_file=fast_image_processing_module_file,
        fast_image_processor_name=fast_image_processor_name,
        model_name=model_name,
    )

    add_fast_image_processor_to_auto(
        image_processor_name=image_processor_name,
        fast_image_processor_name=fast_image_processor_name,
    )

    add_fast_image_processor_to_doc(
        fast_image_processor_name=fast_image_processor_name,
        model_name=model_name,
    )

    add_fast_image_processor_to_tests(
        fast_image_processor_name=fast_image_processor_name,
        model_name=model_name,
    )

    add_fast_image_processor_file(
        fast_image_processing_module_file=fast_image_processing_module_file,
        fast_image_processor_name=fast_image_processor_name,
        content_base_file=content_base_file,
    )


def add_new_model_like_command_factory(args: Namespace):
    """Build an `AddFastImageProcessorCommand` from parsed CLI arguments (registered via `set_defaults(func=...)`)."""
    return AddFastImageProcessorCommand(model_name=args.model_name)


class AddFastImageProcessorCommand(BaseTransformersCLICommand):
    """CLI command `add-fast-image-processor`: scaffolds a fast image processor for an existing model."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `add-fast-image-processor` subcommand and its `--model-name` argument."""
        add_fast_image_processor_parser = parser.add_parser("add-fast-image-processor")
        add_fast_image_processor_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The name of the folder containing the model's implementation.",
        )
        add_fast_image_processor_parser.set_defaults(func=add_new_model_like_command_factory)

    def __init__(self, model_name: str, *args):
        self.model_name = model_name

    def run(self):
        """Run the scaffolding for `self.model_name`."""
        add_fast_image_processor(model_name=self.model_name)