from torch import nn
from torch.nn import functional as F
from torchvision import models
"""
reference = [
"https://github.com/GeorgeSeif/Semantic-Segmentation-Suite/blob/master/models/refine_net.py",
"https://arxiv.org/pdf/1611.06612.pdf",
"Building Extraction in very High resolution Imagery by Dense Attention Networks",
]
backbone_features - represent the output of backbone network
refine_block_features - represent the output of Refine Block
"""
[docs]
def convolution_3x3(in_planes, out_planes, stride=1, padding=1, bias=True):
    """Build a 3x3 ``nn.Conv2d`` from ``in_planes`` to ``out_planes`` channels.

    With the default ``padding=1`` and ``stride=1`` the spatial size is preserved.
    """
    layer_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": padding,
        "bias": bias,
    }
    return nn.Conv2d(in_planes, out_planes, **layer_options)
def convolution_1x1(in_planes, out_planes, stride=1, padding=0, bias=True):
    """Build a pointwise (1x1) ``nn.Conv2d`` from ``in_planes`` to ``out_planes`` channels.

    With the defaults (``stride=1``, ``padding=0``) only the channel count
    changes; the spatial size is preserved.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=padding,
        bias=bias,
    )
class SpatialAttentionFusionModule(nn.Module):
    """Spatial attention fusion of low-level and high-level features.

    Inspired by attention mechanisms, the sigmoid of the high-level features is
    used as a gate on the low-level features. This emphasises useful low-level
    detail and suppresses noise, so low-level features are not over-used. The
    gated low-level features are then added to the high-level features.
    The module has no parameters.
    """

    def forward(self, low_level_features, high_level_features):
        """
        :param low_level_features: features extracted from the backbone
        :param high_level_features: upsampled features; must be
            broadcast-compatible with ``low_level_features``
        :return: ``high_level_features + sigmoid(high_level_features) * low_level_features``
        """
        gate = high_level_features.sigmoid()
        return high_level_features + low_level_features * gate
class ResidualConvolutionUnit(nn.Module):
    """Residual convolution unit (RCU), RefineNet Section 3.2.

    The first part of each RefineNet block consists of an adaptive convolution
    set that mainly fine-tunes the pre-trained ResNet weights. Each unit
    computes ``x + conv2(relu(conv1(relu(x))))``.

    The skip connection adds the *un-activated* input, so ``in_planes`` must
    equal ``out_planes`` for the addition to be shape-compatible.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        # Applied only after the first convolution. That output is used
        # nowhere else, so activating it in place is safe.
        self.non_linearity = nn.ReLU(inplace=True)
        self.convolution_layer_1 = convolution_3x3(in_planes, out_planes)
        self.convolution_layer_2 = convolution_3x3(out_planes, out_planes)

    def forward(self, x):
        """
        :param x: input features with ``in_planes`` channels
        :return: ``x`` plus the two-convolution residual branch
        """
        # Out-of-place ReLU on purpose. An in-place activation here would
        # overwrite ``x``, which is also the skip connection, so the unit would
        # add relu(x) instead of x and silently mutate the caller's tensor.
        out = F.relu(x)
        out = self.convolution_layer_1(out)
        out = self.non_linearity(out)
        out = self.convolution_layer_2(out)
        return x + out
class MultiResolutionFusion(nn.Module):
    """Multi-resolution fusion stage of a RefineNet block (Section 3.2).

    Both input paths are adapted to ``out_planes`` channels with a 3x3
    convolution. The coarser refine-block path is bilinearly upsampled to the
    backbone path's spatial size. The two are then merged by summation, or by a
    ``SpatialAttentionFusionModule`` when ``fusion_module`` is truthy.

    :param in_planes: channels of the incoming refine-block features
    :param out_planes: channels of the backbone features and of the output
    :param fusion_module: if truthy, fuse with spatial attention instead of a plain sum
    """

    def __init__(self, in_planes, out_planes, fusion_module):
        super().__init__()
        # Adapts the coarser refine-block features (in_planes -> out_planes).
        self.convolution_layer_lower_inputs = convolution_3x3(in_planes, out_planes)
        # Adapts the backbone features, which already have out_planes channels.
        self.convolution_layer_higher_inputs = convolution_3x3(out_planes, out_planes)
        self.fusion_module = SpatialAttentionFusionModule() if fusion_module else None

    def forward(self, backbone_features, refine_block_features=None):
        """
        :param backbone_features: backbone-path features (``out_planes`` channels)
        :param refine_block_features: output of the previous RefineNet block
            (``in_planes`` channels), or None for a single input path (RefineNet-4)
        :return: fused features at the spatial size of ``backbone_features``
        """
        backbone_features = self.convolution_layer_higher_inputs(backbone_features)
        if refine_block_features is None:
            # Single input path (RefineNet-4): only the adaptive convolution is applied.
            return backbone_features

        refine_block_features = self.convolution_layer_lower_inputs(refine_block_features)
        # Upsample to the exact backbone size rather than a fixed factor of 2.
        # A fixed x2 breaks the sum when the finer map is not exactly twice the
        # coarser one (e.g. inputs that are not multiples of 32). With
        # align_corners=True the result is identical whenever the sizes are exactly 2x.
        refine_block_features = F.interpolate(
            refine_block_features,
            size=backbone_features.shape[-2:],
            mode="bilinear",
            align_corners=True,
        )
        if self.fusion_module is not None:
            return self.fusion_module(backbone_features, refine_block_features)
        return refine_block_features + backbone_features
class ChainedResidualPooling(nn.Module):
    """Chained residual pooling (RefineNet; motivated in Section 1).

    Chained residual pooling captures background context from a large image
    region. Two pooling stages are chained. Each stage max-pools the previous
    stage's output (5x5 window, stride 1, padding 2, so the spatial size is
    kept) and then applies a 3x3 convolution. Each stage's output is added to a
    running sum that starts from the ReLU-activated input. Because of these
    sums, ``in_planes`` must equal ``out_planes``.

    NOTE(review): both stages reuse the single ``convolution_layer_1`` (shared
    weights). The paper appears to use a separate convolution per pooling
    block; confirm before changing, since adding a layer would alter the
    state_dict.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.non_linearity = nn.ReLU(inplace=True)
        self.convolution_layer_1 = convolution_3x3(
            in_planes=in_planes, out_planes=out_planes
        )
        self.max_pooling_layer = nn.MaxPool2d((5, 5), stride=1, padding=2)

    def forward(self, x):
        # The ReLU is in place: the caller's tensor ``x`` is overwritten with relu(x).
        activated = self.non_linearity(x)
        running_sum = activated
        stage_output = activated
        for _ in range(2):
            stage_output = self.convolution_layer_1(
                self.max_pooling_layer(stage_output)
            )
            running_sum = stage_output + running_sum
        return running_sum
class RefineBlock(nn.Module):
    """One RefineNet block (Section 3.2).

    Pipeline: RCU -> RCU -> multi-resolution fusion -> chained residual pooling -> RCU.

    NOTE(review): all three RCU applications use the single
    ``residual_convolution_unit`` module, so they share weights. The paper
    appears to use distinct units; confirm before changing, because splitting
    them would alter the state_dict.

    :param in_planes: channels of the features from the previous (coarser) RefineNet block
    :param out_planes: channels of the backbone features fed to this block and of its output
    :param fusion_module: forwarded to ``MultiResolutionFusion``; truthy enables
        spatial attention fusion
    """

    def __init__(self, in_planes, out_planes, fusion_module):
        super().__init__()
        self.residual_convolution_unit = ResidualConvolutionUnit(out_planes, out_planes)
        self.multi_resolution_fusion = MultiResolutionFusion(
            in_planes, out_planes, fusion_module
        )
        self.chained_residual_pooling = ChainedResidualPooling(out_planes, out_planes)

    def forward(self, backbone_features, refine_block_features=None):
        """
        :param backbone_features: input from the backbone network (``out_planes`` channels)
        :param refine_block_features: output of the previous RefineNet block
            (``in_planes`` channels), or None for RefineNet-4 (single input path)
        :return: refined features with ``out_planes`` channels
        """
        # Adaptive convolution set: two RCUs that mainly fine-tune the
        # pre-trained ResNet features.
        x = self.residual_convolution_unit(backbone_features)
        x = self.residual_convolution_unit(x)
        # MultiResolutionFusion handles refine_block_features=None itself
        # (single input path, RefineNet-4), so both cases share one code path.
        x = self.multi_resolution_fusion(x, refine_block_features)
        x = self.chained_residual_pooling(x)
        # The final step of each RefineNet block is another RCU.
        return self.residual_convolution_unit(x)
class ReFineNet(nn.Module):
    """RefineNet semantic-segmentation model on a torchvision ResNet backbone.

    Processing steps:
    - The outputs of the four ResNet stages (1/4, 1/8, 1/16 and 1/32
      resolution) are channel-reduced by 1x1 convolutions.
    - They are refined coarse-to-fine by four RefineBlocks.
    - The result passes through two extra residual convolution units.
    - It is projected to ``num_classes`` channels and bilinearly resized to the
      spatial size of the input.

    :param res_net_to_use: name of a torchvision ResNet constructor, one of
        ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``
    :param pre_trained_image_net: forwarded as ``pretrained=`` to the constructor
        (NOTE(review): newer torchvision versions deprecate ``pretrained`` in favour
        of ``weights``; confirm against the installed version)
    :param top_layers_trainable: if False, every backbone parameter is frozen
    :param num_classes: number of channels of the output map
    :param fusion_module: if truthy, RefineBlocks fuse with spatial attention
        instead of a plain sum
    :raises NotImplementedError: if the backbone's stage channel counts are not
        listed in ``_BACKBONE_CHANNELS``
    """

    # Output channels of layer1..layer4 for each supported torchvision ResNet
    # (BasicBlock models: 64..512; Bottleneck models: 256..2048).
    _BACKBONE_CHANNELS = {
        "resnet18": (64, 128, 256, 512),
        "resnet34": (64, 128, 256, 512),
        "resnet50": (256, 512, 1024, 2048),
        "resnet101": (256, 512, 1024, 2048),
        "resnet152": (256, 512, 1024, 2048),
    }

    def __init__(
        self,
        res_net_to_use,
        pre_trained_image_net,
        top_layers_trainable=True,
        num_classes=1,
        fusion_module=False,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.fusion_module = fusion_module

        res_net = getattr(models, res_net_to_use)(pretrained=pre_trained_image_net)
        if not top_layers_trainable:
            # Freezes the whole backbone (all of its parameters).
            res_net.requires_grad_(False)

        self.layer0 = nn.Sequential(
            res_net.conv1, res_net.bn1, res_net.relu, res_net.maxpool
        )
        self.layer1 = res_net.layer1
        self.layer2 = res_net.layer2
        self.layer3 = res_net.layer3
        self.layer4 = res_net.layer4

        layers_features = self._BACKBONE_CHANNELS.get(res_net_to_use)
        if layers_features is None:
            raise NotImplementedError(
                f"Unsupported backbone {res_net_to_use!r}; expected one of "
                f"{sorted(self._BACKBONE_CHANNELS)}"
            )

        # Section 3.1: in practice each ResNet output is passed through one
        # convolution layer to adapt its dimensionality.
        self.convolution_layer_4_dim_reduction = convolution_1x1(
            in_planes=layers_features[-1], out_planes=512
        )
        self.convolution_layer_3_dim_reduction = convolution_1x1(
            in_planes=layers_features[-2], out_planes=256
        )
        self.convolution_layer_2_dim_reduction = convolution_1x1(
            in_planes=layers_features[-3], out_planes=256
        )
        self.convolution_layer_1_dim_reduction = convolution_1x1(
            in_planes=layers_features[-4], out_planes=256
        )

        self.refine_block_4 = RefineBlock(
            in_planes=512, out_planes=512, fusion_module=self.fusion_module
        )
        self.refine_block_3 = RefineBlock(
            in_planes=512, out_planes=256, fusion_module=self.fusion_module
        )
        self.refine_block_2 = RefineBlock(
            in_planes=256, out_planes=256, fusion_module=self.fusion_module
        )
        self.refine_block_1 = RefineBlock(
            in_planes=256, out_planes=256, fusion_module=self.fusion_module
        )

        # Section 3.2: the final step of each RefineNet block is another RCU,
        # giving three RCUs between blocks. To mirror this after the last
        # block (RefineNet-1), two additional RCUs are applied. Here both
        # applications use this single module, so they share weights.
        self.residual_convolution_unit = ResidualConvolutionUnit(
            in_planes=256, out_planes=256
        )
        self.final_layer = convolution_1x1(in_planes=256, out_planes=self.num_classes)

    def forward(self, input_feature):
        """
        :param input_feature: image batch for the ResNet backbone
            (presumably (N, 3, H, W); TODO confirm against callers)
        :return: map with ``num_classes`` channels, resized to the spatial size
            of ``input_feature``
        """
        layer_0_output = self.layer0(input_feature)  # 1/4
        layer_1_output = self.layer1(layer_0_output)  # 1/4
        layer_2_output = self.layer2(layer_1_output)  # 1/8
        layer_3_output = self.layer3(layer_2_output)  # 1/16
        layer_4_output = self.layer4(layer_3_output)  # 1/32

        backbone_layer_4 = self.convolution_layer_4_dim_reduction(layer_4_output)  # 1/32
        backbone_layer_3 = self.convolution_layer_3_dim_reduction(layer_3_output)  # 1/16
        backbone_layer_2 = self.convolution_layer_2_dim_reduction(layer_2_output)  # 1/8
        backbone_layer_1 = self.convolution_layer_1_dim_reduction(layer_1_output)  # 1/4

        # Coarse-to-fine refinement: each block also consumes the previous block's output.
        refine_block_4 = self.refine_block_4(backbone_layer_4)
        refine_block_3 = self.refine_block_3(backbone_layer_3, refine_block_4)
        refine_block_2 = self.refine_block_2(backbone_layer_2, refine_block_3)
        refine_block_1 = self.refine_block_1(backbone_layer_1, refine_block_2)

        # Two additional RCUs after RefineNet-1 (Section 3.2).
        refined = self.residual_convolution_unit(refine_block_1)
        refined = self.residual_convolution_unit(refined)

        final_map = self.final_layer(refined)
        # Resize to the input's exact spatial size. A fixed x4 yields
        # 4 * ceil(H / 4) rows, which differs from H when H is not a multiple
        # of 4. With align_corners=True the result is identical when it is.
        return F.interpolate(
            final_map,
            size=input_feature.shape[-2:],
            mode="bilinear",
            align_corners=True,
        )