# coding=utf-8
# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch EfficientNet model."""

import math
from typing import Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_efficientnet import EfficientNetConfig


logger = logging.get_logger(__name__)


def round_filters(config: EfficientNetConfig, num_channels: int):
    r"""
    Scale a channel count by the width coefficient and round it to a multiple of the depth divisor.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration; `config.width_coefficient` and `config.depth_divisor` are read.
        num_channels (`int`):
            Unscaled number of channels (filters).

    Returns:
        `int`: The scaled channel count, rounded to the nearest multiple of `config.depth_divisor` and never smaller
        than `config.depth_divisor`.
    """
    divisor = config.depth_divisor
    num_channels *= config.width_coefficient
    # Round to the nearest multiple of `divisor`, with `divisor` itself as the lower bound.
    new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)

    # Make sure that round down does not go down by more than 10%.
    if new_dim < 0.9 * num_channels:
        new_dim += divisor

    return int(new_dim)


def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers. An `int` is used for both dimensions.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts padding value to apply to right and bottom sides of the input.

    Returns:
        `tuple` of 4 `int`: Padding amounts, passed as-is to `nn.ZeroPad2d` by the depthwise layer. The first pair is
        derived from `kernel_size[1]` and the second pair from `kernel_size[0]`. When `adjust` is `True`, the first
        entry of each pair is one less than the second.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    if adjust:
        return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
    else:
        return (correct[1], correct[1], correct[0], correct[0])


class EfficientNetEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.

    Applies zero padding, a strided 3x3 convolution, batch normalization and the configured activation.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()

        # Number of stem output channels: 32 scaled by the width coefficient.
        self.out_dim = round_filters(config, 32)
        # Explicit asymmetric zero padding; the convolution below uses no padding of its own ("valid").
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features


class EfficientNetDepthwiseConv2d(nn.Conv2d):
    r"""
    `nn.Conv2d` configured as a depthwise convolution: `groups` is set to `in_channels` and the number of output
    channels is `in_channels * depth_multiplier`.
    """

    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )


class EfficientNetExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.

    Applies a 1x1 convolution from `in_dim` to `out_dim` channels, batch normalization and the configured activation.
    The `stride` argument is accepted but not used by this layer.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        # NOTE(review): `momentum` is not passed here, so the PyTorch default applies, unlike the other batch norm
        # layers in this file which use `config.batch_norm_momentum`. Left as-is to keep training behavior unchanged.
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states


class EfficientNetDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.

    Applies a depthwise convolution, batch normalization and the configured activation. When `stride == 2`, the input
    is zero padded explicitly (see `correct_pad`) and the convolution itself uses no padding; otherwise the
    convolution uses `"same"` padding and the explicit padding module is not applied.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        # Always constructed, but only applied in `forward` when `stride == 2`.
        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = EfficientNetDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states


class EfficientNetSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the Squeeze and Excitement phase of each block in the original implementation.

    The input is average pooled to a 1x1 spatial size, passed through two 1x1 convolutions (reduce, then expand) with
    the configured activation and a sigmoid respectively, and the result is multiplied with the input.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels of the enclosing block; used to size the reduced dimension.
        expand_dim (`int`):
            Number of channels after the expansion phase.
        expand (`bool`, *optional*, defaults to `False`):
            Whether the enclosing block has an expansion phase. Selects `expand_dim` (if `True`) or `in_dim` as the
            number of channels this layer operates on.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        self.dim = expand_dim if expand else in_dim
        # The reduced dimension is derived from the block input channels and is at least 1.
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        # Rescale the original input with the computed gates.
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states


class EfficientNetFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.

    Applies a 1x1 projection convolution and batch normalization. When `stride == 1` and `id_skip` is `False`,
    dropout is applied and the block input (`embeddings`) is added to the result.
    """

    def __init__(
        self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        # Controls both the dropout and the residual addition in `forward`.
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states


class EfficientNetBlock(nn.Module):
    r"""
    This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
            of each block. Set to `True` for the first block of each stage.
        adjust_padding (`bool`):
            Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
            operation, set to `True` for inputs with odd input sizes.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        # The expansion layer only exists when the expand ratio differs from 1.
        self.expand = self.expand_ratio != 1
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = EfficientNetExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        self.depthwise_conv = EfficientNetDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = EfficientNetSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = EfficientNetFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Keep the block input for the (optional) residual addition in the projection layer.
        embeddings = hidden_states
        # Expansion and depthwise convolution phase
        if self.expand:
            hidden_states = self.expansion(hidden_states)
        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)
        return hidden_states


class EfficientNetEncoder(nn.Module):
    r"""
    Forward propagates the embeddings through each EfficientNet block.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()
        self.config = config
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round number of block repeats based on depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))

        num_base_blocks = len(config.in_channels)
        # Total number of blocks after depth scaling; used to scale the per-block drop rate.
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                # Only the first block of a stage uses the configured stride and input width; repeated blocks use
                # stride 1 and take the stage's output width as input.
                id_skip = j == 0
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = curr_block_num not in config.depthwise_padding
                # The drop rate grows linearly with the block index.
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = EfficientNetBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)
        self.top_conv = nn.Conv2d(
            in_channels=out_dim,
            out_channels=round_filters(config, 1280),
            kernel_size=1,
            padding="same",
            bias=False,
        )
        # NOTE(review): sized from `config.hidden_dim`, while `top_conv` above is sized from
        # `round_filters(config, 1280)`; the two are presumably expected to match - confirm against the config.
        self.top_bn = nn.BatchNorm2d(
            num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.top_activation = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithNoAttention:
        # When requested, collect the encoder input followed by the output of every block.
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        hidden_states = self.top_conv(hidden_states)
        hidden_states = self.top_bn(hidden_states)
        hidden_states = self.top_activation(hidden_states)

        if not return_dict:
            # Tuple output drops entries that are `None`.
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )


@auto_docstring
class EfficientNetPreTrainedModel(PreTrainedModel):
    config: EfficientNetConfig
    base_model_prefix = "efficientnet"
    main_input_name = "pixel_values"
    _no_split_modules = []

    def _init_weights(self, module: nn.Module):
        """Initialize the weights of linear, convolution and batch norm modules.

        Weights are drawn from a normal distribution with standard deviation `config.initializer_range`; biases,
        when present, are set to zero. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


@auto_docstring
class EfficientNetModel(EfficientNetPreTrainedModel):
    def __init__(self, config: EfficientNetConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = EfficientNetEmbeddings(config)
        self.encoder = EfficientNetEncoder(config)

        # Final pooling layer
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            # Report the attribute that was actually checked (`pooling_type`); formatting a different attribute
            # name here would fail on the attribute lookup before the intended `ValueError` could be raised.
            raise ValueError(
                f"config.pooling_type must be one of ['mean', 'max'] got {config.pooling_type}"
            )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # Fall back to the configuration for flags that were not passed explicitly.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Apply pooling
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
    for ImageNet.
    """
)
class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.efficientnet = EfficientNetModel(config)
        # Classifier head
        self.dropout = nn.Dropout(p=config.dropout_rate)
        # With no labels configured the head is a pass-through, so the "logits" are the pooled features.
        self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # The pooled features are the second element of the backbone's tuple output.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            # Skip the backbone's last hidden state and pooled output; keep any remaining entries.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


__all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"]