biasutti2019riu

Summarizing quotes from the paper:

1. Introduction

  • “we propose RIU-Net (for Range-Image U-Net), the adaptation of U-Net […] to the semantic segmentation of 3D LiDAR point clouds.”

2. Related works

  • “Recently, Wu et al. proposed SqueezeSeg, a novel approach for the semantic segmentation of a LiDAR point cloud represented as a spherical range-image.”

3. Methodology

  • “The method consists in feeding the U-Net architectures with 2-channels [range-images] encoding range and elevation.”

3.A Input of the Network

  • “we use a range-image named u of 512 × 64px with two channels: the depth towards the sensor and the elevation.”
  • “We propose to identify [empty pixels (pixels that are considered invalid by the LIDAR sensor)] using a binary mask m equal to 0 for empty pixels and to 1 otherwise.”

3.B Architecture

Figure from the original paper (Fig. 2).

  • “the encoder consists in the repeated application of two 3×3 convolutions followed by a rectified linear unit (ReLU) and a 2×2 max-pooling layer that downsamples the input by a factor 2. Each time a downsampling is done, the number of features is doubled.”
  • “The decoder consists in upsampling blocks where the input is upsampled using 2 × 2 up-convolutions. Then, concatenation is done between the upsampled feature map and the corresponding feature map of the encoder. After that, two 3 × 3 convolutions are applied followed by a ReLU. This block is repeated until the output of the network matches the dimension of the input.”
  • “the last layer consists in a 1x1 convolution that outputs as many features as the wanted number of possible labels K 1-hot encoded.”

3.C Loss function

  • “is defined as the cross-entropy of the softmax of the output of the network.”

$$E = \sum_{x \in \Omega} \mathbb{1}_{\{m(x) > 0\}} \, w(x) \log\big(p_{l(x)}(x)\big)$$

  • “[in the equation shown above (corrected from the paper)] we define l(x) the groundtruth label of the x pixel. […] m(x)>0 are the valid pixels [(non-empty pixels)] and w(x) is a weighting function introduced to give more importance to pixels that are close to a separation between two labels, as defined in U-Net.”

3.D Training

  • “We train the network with the Adam stochastic gradient optimizer and a learning rate set to 0.001. We also use batch normalization with a momentum of 0.99 to ensure good convergence of the model. Finally, the batch size is set to 8 and the training is stopped after 10 epochs.”

4. Experiments
  • “we follow the experimental setup of the SqueezeSeg approach […] which contains 8057 samples for training and 2791 for validation.”
  • “we use the intersection-over-union metric”
  • “we advocate that the proposed model can operate with a frame-rate of 90 frames per second on a single GPU.”

RIUNet architecture


source

Block

 Block (in_channels, out_channels)

Convolutional block repeatedly used in the RIU-Net encoder and decoder.

Exported source
class Block(Sequential):
    "Convolutional block repeatedly used in the RIU-Net encoder and decoder."
    def __init__(self, in_channels, out_channels):
        super().__init__(OrderedDict([
            (f'conv1', Conv2d(in_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn1', BatchNorm2d(out_channels, momentum=0.01)),
            (f'relu1', ReLU()),
            (f'conv2', Conv2d(out_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn2', BatchNorm2d(out_channels, momentum=0.01)),
            (f'relu2', ReLU()),
        ]))
        self.init_params()
    
    def init_params(self):
        for n, p in self.named_parameters():
            if re.search(r'conv\d\.weight', n):
                kaiming_normal_(p, nonlinearity='relu')

It implements the following architecture:

Input (bs, in_c, h, w) → Conv(3x3) [in_c -> out_c] → BatchNorm2d → ReLU → Conv(3x3) [out_c -> out_c] → BatchNorm2d → ReLU → Output (bs, out_c, h, w)

Here is an example of how to use it:

bs, in_c, out_c, h, w = 1, 5, 64, 64, 512
inp = torch.randn(bs, in_c, h, w)

b = Block(in_c, out_c)
outp = b(inp)
assert outp.shape == (bs, out_c, h, w)
print(outp.shape, f'== ({bs}, {out_c}, {h}, {w})')
torch.Size([1, 64, 64, 512]) == (1, 64, 64, 512)

It initializes the weights of the conv layers with the kaiming_normal_ algorithm in fan_in mode, as described in U-Net (page 5).
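As a quick worked check of what that means numerically (this mirrors the asserts below), the first convolution of Block(5, 64) has fan_in = 5·3·3 = 45, so its weights should be drawn from a zero-mean normal with σ = √(2/45) ≈ 0.21:

import numpy as np

fan_in = 5 * 3 * 3           # in_channels * kernel_h * kernel_w for conv1 of Block(5, 64)
print(np.sqrt(2. / fan_in))  # expected std under kaiming_normal_(nonlinearity='relu'): ≈ 0.2108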

from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
colors_list = list(mcolors.TABLEAU_COLORS)

def plot_param_dists(net, param_re_pattern, nonlinearity_gain):
    color_idx = 0
    for n, p in net.named_parameters():
        if re.search(param_re_pattern, n):
            x_range = 1.1*p.data.max()
            ## kaiming normal dist
            fan_in = p.shape[1]*p.shape[2]*p.shape[3]
            mu, sigma = 0., np.sqrt(nonlinearity_gain/fan_in)
            x = np.linspace(-x_range, x_range, 100)
            y = ((1./(np.sqrt(2*np.pi)*sigma))*np.exp(-0.5*((1./sigma)*(x - mu))**2))
            plt.plot(x, y, '--', color=colors_list[color_idx], label='Expected '+n)
            ## sampled weight dist
            plt.hist(p.view(-1).data, 30, density=True, alpha=0.5, color=colors_list[color_idx], label='Actual '+n)
            color_idx += 1
    plt.legend();
plot_param_dists(b, r'conv\d\.weight', 2.)

for n, p in b.named_parameters():
    if re.search(r'conv\d\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., np.sqrt(2./fan_in)
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-2
        assert abs(sigma - p_data.std()) < 1e-2

source

Encoder

 Encoder (channels=(5, 64, 128, 256, 512, 1024))

RIU-Net encoder architecture.

Exported source
class Encoder(Module):
    "RIU-Net encoder architecture."
    def __init__(self, channels=(5, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.blocks = ModuleList(
            [Block(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
    
    def forward(self, x):
        enc_features = []
        for block in self.blocks:
            x = block(x)
            enc_features.append(x)
            x = F.max_pool2d(x, 2)
        return enc_features

It implements the following architecture:

Input (bs, 5, h, w) → Block [5 -> 64] → MaxPool(2x2) → Block [64 -> 128] → MaxPool(2x2) → Block [128 -> 256] → MaxPool(2x2) → Block [256 -> 512] → MaxPool(2x2) → Block [512 -> 1024]

Output: [(bs, 64, h, w), (bs, 128, h/2, w/2), (bs, 256, h/4, w/4), (bs, 512, h/8, w/8), (bs, 1024, h/16, w/16)]

Here is an example of how to use it:

enc = Encoder()
outp = enc(inp)
[o.shape for o in outp]
[torch.Size([1, 64, 64, 512]),
 torch.Size([1, 128, 32, 256]),
 torch.Size([1, 256, 16, 128]),
 torch.Size([1, 512, 8, 64]),
 torch.Size([1, 1024, 4, 32])]

For the decoder, the paper mentions the application of “up-convolutions”, which were first defined in U-Net as:

“an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”)”

Our implementation changes the 2x2 convolution to a 3x3 one (with padding) so that no cropping of the feature maps is needed to handle shape mismatches with the encoder skip connections.
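A quick shape check illustrates the difference (the sizes below mirror the deepest encoder level of the example above):

import torch
from torch.nn import Conv2d, Upsample

x = Upsample(scale_factor=2)(torch.randn(1, 1024, 4, 32))   # upsampled to (1, 1024, 8, 64)
print(Conv2d(1024, 512, 2)(x).shape)             # 2x2 conv: torch.Size([1, 512, 7, 63]) -> would need cropping
print(Conv2d(1024, 512, 3, padding=1)(x).shape)  # 3x3 conv, padding 1: torch.Size([1, 512, 8, 64]) -> matches the skip connection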


source

UpConv

 UpConv (in_channels, out_channels)

Up-convolution operation adapted from U-Net.

Exported source
class UpConv(Sequential):
    "Up-convolution operation adapted from [U-Net](https://arxiv.org/abs/1505.04597)."
    def __init__(self, in_channels, out_channels):
        super().__init__(OrderedDict([
            (f'upsample', Upsample(scale_factor=2)),
            (f'conv', Conv2d(in_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn', BatchNorm2d(out_channels, momentum=0.01)),
        ]))
        self.init_params()
    
    def init_params(self):
        for n, p in self.named_parameters():
            if re.search('conv.weight', n):
                kaiming_normal_(p, nonlinearity='linear')

Here is an example of how to use it:

upc = UpConv(1024, 512)
outp = enc(inp)
print(f'before: {outp[-1].shape}')
outp = upc(outp[-1])
print(f'after: {outp.shape}')
before: torch.Size([1, 1024, 4, 32])
after: torch.Size([1, 512, 8, 64])

source

Decoder

 Decoder (channels=(1024, 512, 256, 128, 64))

RIU-Net decoder architecture.

Exported source
class Decoder(Module):
    "RIU-Net decoder architecture."
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.upconvs = ModuleList(
            [UpConv(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
        self.blocks = ModuleList(
            [Block(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
    
    def forward(self, enc_features):
        x = enc_features[-1]
        for i, (upconv, block) in enumerate(zip(self.upconvs, self.blocks)):
            x = upconv(x)
            x = torch.cat([x, enc_features[-(i+2)]], dim=1)
            x = block(x)
        return x

It implements the following architecture:

Input: [(bs, 1024, h/16, w/16), (bs, 512, h/8, w/8), (bs, 256, h/4, w/4), (bs, 128, h/2, w/2), (bs, 64, h, w)]

UpConv [1024 -> 512] → Block [concat(512, 512) -> 512] → UpConv [512 -> 256] → Block [concat(256, 256) -> 256] → UpConv [256 -> 128] → Block [concat(128, 128) -> 128] → UpConv [128 -> 64] → Block [concat(64, 64) -> 64]

Output (bs, 64, h, w)

Here is an example of how to use it:

dec = Decoder()
outp = enc(inp)
fts = dec(outp)
assert fts.shape == (bs, out_c, h, w)
print(fts.shape, f'== ({bs}, {out_c}, {h}, {w})')
torch.Size([1, 64, 64, 512]) == (1, 64, 64, 512)

It initializes the weights of the upconv layers with the kaiming_normal_ algorithm in fan_in mode and the nonlinearity set to ‘linear’, since no ReLU is used in the operation.

plot_param_dists(dec, r'upconvs\.\d\.conv\.weight', 1.)

for n, p in dec.named_parameters():
    if re.search(r'upconvs\.\d\.conv\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., np.sqrt(1./fan_in)
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-3
        assert abs(sigma - p_data.std()) < 1e-3

source

RIUNet

 RIUNet (in_channels=5, hidden_channels=(64, 128, 256, 512, 1024),
         n_classes=20)

RIU-Net complete architecture.

Exported source
class RIUNet(Module):
    "RIU-Net complete architecture."
    def __init__(self, in_channels=5, hidden_channels=(64, 128, 256, 512, 1024), n_classes=20):
        super().__init__()
        self.n_classes = n_classes
        self.input_norm = BatchNorm2d(in_channels, affine=False, momentum=None)
        self.backbone = Sequential(OrderedDict([
            (f'enc', Encoder((in_channels, *hidden_channels))),
            (f'dec', Decoder(hidden_channels[::-1]))
        ]))
        self.head = Conv2d(hidden_channels[0], n_classes, 1)
        self.init_params()

    def init_params(self):
        for n, p in self.named_parameters():
            if re.search(r'head\.weight', n):
                normal_(p, std=1e-2)
            if re.search(r'head\.bias', n):
                zeros_(p)
    
    def forward(self, x):
        x = self.input_norm(x)
        features = self.backbone(x)
        prediction = self.head(features)
        
        return prediction

We slightly changed it to accept 5 input channels (i.e. x, y, z, depth and reflectance) instead of the 2 (depth and elevation) proposed in the original paper.
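For reference, here is a hedged sketch of how such a 5-channel input could be built from a raw scan by spherical projection. The helper name and the vertical field-of-view bounds are our own illustrative assumptions (roughly a 64-beam sensor), not values taken from the paper, and out-of-view points are simply clamped for brevity.

import math
import torch

def project_to_range_image(points, reflectance, h=64, w=512, fov_up=3.0, fov_down=-25.0):
    "points: (N, 3) xyz coordinates; reflectance: (N,). Returns a (5, h, w) image and an (h, w) validity mask."
    x, y, z = points.unbind(1)
    depth = points.norm(dim=1)
    yaw, pitch = torch.atan2(y, x), torch.asin(z / depth)
    fov_up, fov_down = math.radians(fov_up), math.radians(fov_down)
    u = ((1 - (yaw / math.pi + 1) / 2) * w).long().clamp(0, w - 1)                   # column from azimuth
    v = ((1 - (pitch - fov_down) / (fov_up - fov_down)) * h).long().clamp(0, h - 1)  # row from elevation
    img = torch.zeros(5, h, w)
    mask = torch.zeros(h, w, dtype=torch.bool)
    img[:, v, u] = torch.stack([x, y, z, depth, reflectance])                        # last point wins on collisions
    mask[v, u] = True
    return img, mask

points, refl = torch.randn(1000, 3), torch.rand(1000)
img, mask = project_to_range_image(points, refl)
assert img.shape == (5, 64, 512) and mask.shape == (64, 512)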

It implements the following architecture:

Input (bs, 5, h, w) → Encoder → Decoder → Conv(1x1) [64 -> 20] → Output (bs, 20, h, w)

Here is an example of how to use it:

n_classes=20
model = RIUNet()
logits = model(inp)
assert logits.shape == (bs, n_classes, h, w)
print(logits.shape, f'== ({bs}, {n_classes}, {h}, {w})')
torch.Size([1, 20, 64, 512]) == (1, 20, 64, 512)

It initializes the weights of the classification head from a normal distribution with a standard deviation of 1e-2. The motivation is to reduce any random bias in the outputs of the untrained model.

plot_param_dists(model, r'head\.weight', 0.0064)

for n, p in model.named_parameters():
    if re.search(r'head\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., 1e-2
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-2
        assert abs(sigma - p_data.std()) < 1e-2

Loss function

The proposed equation for the loss function is the following.

$$E = \sum_{x \in \Omega} w(x) \log\big(p_{l(x)}(x)\big) \, \mathbb{1}_{\{m(x) > 0\}}$$

The factor w is motivated as a way “to give more importance to pixels that are close to a separation between two labels”. It was first defined in U-Net as follows.

$$w(x) = w_c(x) + w_0 \cdot \exp\left(-\frac{\big(d_1(x) + d_2(x)\big)^2}{2\sigma^2}\right)$$

From U-Net:

“where wc […] is the weight map to balance the class frequencies, d1 […] denotes the distance to the border of the nearest cell and d2 […] the distance to the border of the second nearest cell.”

While this particular equation for w can be seen as very specific to the original biomedical application of the U-Net paper, its motivation is generally valid for any image semantic segmentation task such as ours. Hence, we can rewrite the equation for w as follows.

$$w(x) = w_c + \lambda \, w_b$$

As in the previous equation, wc accounts for class imbalance, but the second term is rewritten as a general wb factor that accounts for the boundaries of the semantic maps.

For now, we only implement the wc factor. We leave the wb factor, and the experiments to evaluate its impact on the final model, as TODO items.
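As an illustration, here is a hedged sketch of one common way to derive wc from the label frequencies of the training set (inverse frequency, normalized); the exact weighting scheme is a design choice and the helper below is only a suggestion, not something prescribed by the paper.

import torch

def class_weights_from_labels(labels, mask, n_classes=20):
    "labels, mask: (N, H, W) tensors. Returns an (n_classes,) weight vector usable as w_c."
    counts = torch.bincount(labels[mask].flatten(), minlength=n_classes).float()
    freq = counts / counts.sum().clamp(min=1)
    w_c = 1. / (freq + 1e-6)   # rarer classes get larger weights (absent classes would need clipping in practice)
    return w_c / w_c.mean()    # normalize so the average weight is 1

labels = torch.randint(0, 20, (4, 64, 512))
mask = torch.rand(4, 64, 512) > 0.1
print(class_weights_from_labels(labels, mask))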

Since the weighting and masking of the cross-entropy loss are already implemented through the weight and ignore_index parameters of PyTorch’s CrossEntropyLoss module, we implement our own wrapper simply for convenience.


source

WeightedMaskedCELoss

 WeightedMaskedCELoss (weight)

Convenient wrapper around the CrossEntropyLoss module with the weight and ignore_index parameters already set.

Exported source
class WeightedMaskedCELoss(Module):
    "Convenient wrapper for the CrossEntropyLoss module with a `weight` and `ignore_index` paremeters already set."
    def __init__(self, weight):
        super().__init__()
        self.ignore_index = -1
        self.wmCE = CrossEntropyLoss(weight=weight, ignore_index=self.ignore_index)

    def forward(self, pred, label, mask):
        label[~mask] = self.ignore_index
        loss = self.wmCE(pred, label)
        return loss
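Here is a minimal usage sketch; the uniform class weights are just placeholders (in practice they would come from the training-set label frequencies, e.g. as sketched above).

import torch

n_classes = 20
loss_fn = WeightedMaskedCELoss(weight=torch.ones(n_classes))

pred = torch.randn(1, n_classes, 64, 512)          # raw logits from the model
label = torch.randint(0, n_classes, (1, 64, 512))  # groundtruth labels
mask = torch.rand(1, 64, 512) > 0.1                # valid (non-empty) pixels
print(loss_fn(pred, label, mask))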

LightningModule for standard experiments


source

log_activations

 log_activations (logger, step, model, img)

Function that uses a PyTorch forward hook to log properties of activations for debugging purposes.

Exported source
def log_activations(logger, step, model, img):
    "Function that uses a Pytorch forward hook to log properties of activations for debugging purposes."
    def debugging_hook(module, inp, out):            
        if 'relu' in module.name:
            acts = out.detach()
            
            min_count = (acts < 1e-1).sum((0, 2, 3))
            shape = acts.shape
            total_count = shape[0]*shape[2]*shape[3]
            rate = min_count/total_count
            logger.log({"max_dead_rate/" + str(module.name): rate.max()}, step=step)
            
            #acts_flat = acts.cpu().view(-1,)
            #acts_hist = np.histogram(acts_flat.log1p(), 100)
            #logger.log({"relu_hist/" + str(module.name): wandb.Histogram(np_histogram=acts_hist)}, step=step)
            
    with register_module_forward_hook(debugging_hook):
        model(img)

source

log_imgs

 log_imgs (pred, label, mask, viz_tfm, logger, stage, step)

TODO: documentation missing

Exported source
def log_imgs(pred, label, mask, viz_tfm, logger, stage, step):
    "TODO: documentation missing"
    pred_np = pred[0].detach().cpu().numpy().argmax(0)
    label_np = label[0].detach().cpu().numpy()
    mask_np = mask[0].detach().cpu().numpy()
    pred_np[pred_np == label_np] = 0
    _, pred_img, _ = viz_tfm(None, pred_np, mask_np)
    _, label_img, _ = viz_tfm(None, label_np, mask_np)
    img_cmp = np.concatenate((pred_img, label_img), axis=0)
    img_cmp = wandb.Image(img_cmp)
    logger.log({f"{stage}_examples": img_cmp}, step=step)

TODO: needs proper documentation with code examples for log_activations and log_imgs functions.
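Until then, here is a hedged interim sketch of how log_activations can be exercised on its own. It assumes a wandb run as the logger (the function calls logger.log(dict, step=...)) and replicates the module-naming step that SemanticSegmentationTask performs below; the project name is hypothetical.

import torch
import wandb

model = RIUNet()
for n, m in model.named_modules():   # log_activations expects every module to expose a `name` attribute
    m.name = n

run = wandb.init(project="riunet-debug", mode="offline")   # hypothetical project, offline to avoid a login
img = torch.randn(1, 5, 64, 512)
log_activations(run, 0, model, img)
run.finish()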


source

SemanticSegmentationTask

 SemanticSegmentationTask (model, loss_fn, viz_tfm, total_steps,
                           lr=0.0005)

Lightning Module to standardize experiments with semantic segmentation tasks.

Exported source
class SemanticSegmentationTask(LightningModule):
    "Lightning Module to standardize experiments with semantic segmentation tasks."
    def __init__(self, model, loss_fn, viz_tfm, total_steps, lr=5e-4):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.viz_tfm = viz_tfm
        self.lr = lr
        self.total_steps = total_steps
        self.train_accuracy = Accuracy(task="multiclass", num_classes=model.n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=model.n_classes)
        self.automatic_optimization = False

        self.step_idx = 0
        
        for n, m in self.model.named_modules():
            assert not hasattr(m, 'name')
            m.name = n
        
    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.lr, eps=1e-5)
        lr_scheduler = OneCycleLR(optimizer, self.lr, self.total_steps)
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
    
    def training_step(self, batch, batch_idx):
        stage = 'train'
        logger = self.logger.experiment
        
        loss, pred, label, mask = self.step(batch, batch_idx, stage, self.train_accuracy)
        if self.step_idx % int(0.01*self.total_steps) == 0:
            log_activations(logger, self.step_idx, self.model, batch[0])
        if self.step_idx % int(0.25*self.total_steps) == 0:
            log_imgs(pred, label, mask, self.viz_tfm, logger, stage, self.step_idx)
        self.manual_optimization(loss)
        self.step_idx += 1
    
    def on_train_epoch_end(self):
        self.log('train_acc_epoch', self.train_accuracy)

    def validation_step(self, batch, batch_idx):
        stage = 'val'
        logger = self.logger.experiment
        
        _, pred, label, mask = self.step(batch, batch_idx, stage, self.val_accuracy)
        if batch_idx == 0:
            log_imgs(pred, label, mask, self.viz_tfm, logger, stage, self.step_idx)
    
    def step(self, batch, batch_idx, stage, metric):
        img, label, mask = batch
        label[~mask] = 0
        
        pred = self.model(img)
        
        loss = self.loss_fn(pred, label)
        loss = loss[mask]
        loss = loss.mean()

        pred_f = torch.permute(pred, (0, 2, 3, 1)) # N,C,H,W -> N,H,W,C
        pred_f = torch.flatten(pred_f, 0, -2)      # N,H,W,C -> N*H*W,C
        mask_f = torch.flatten(mask)               # N,H,W   -> N*H*W
        pred_m = pred_f[mask_f, :]
        label_m = label[mask]
        metric(pred_m, label_m)
        
        self.log(f"{stage}_acc_step", metric)
        self.log(f"{stage}_loss_step", loss.log10())

        return loss, pred, label, mask
    
    def manual_optimization(self, loss):
        optimizer = self.optimizers()
        optimizer.zero_grad()
        self.manual_backward(loss)
        
        p_old = {}
        for n, p in self.model.named_parameters():
            p_old[n] = p.detach().clone()
        
        optimizer.step()
        
        for n, p in self.model.named_parameters():
            optim_step = p.detach() - p_old[n]
            
            #log_rate = optim_step.abs().log1p() - p_old[n].abs().log1p()
            #log_rate_hist = np.histogram(log_rate.cpu().view(-1), 100)
            #self.logger.experiment.log({"log_update_rate_hist/" + str(n): wandb.Histogram(np_histogram=log_rate_hist)}, step=self.step_idx)
            
            self.logger.experiment.log({"ud/" + str(n): ((optim_step.std())/(p_old[n].std() + 1e-5)).log10()}, step=self.step_idx)
        
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
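Finally, a hedged construction sketch for the task. The visualization transform and the plain cross-entropy below are simplified stand-ins for the project's own objects (note that step() calls loss_fn(pred, label) and applies the mask itself, so a reduction='none' loss is what it expects); the actual Trainer.fit call with a WandbLogger and the KITTI dataloaders is not shown.

import numpy as np
from torch.nn import CrossEntropyLoss

def toy_viz_tfm(img, label, mask):
    "Hypothetical stand-in for the project's visualization transform: label map -> RGB image."
    label_rgb = np.stack([label * (255 // 20)] * 3, axis=-1).astype(np.uint8)
    return img, label_rgb, mask

task = SemanticSegmentationTask(
    model=RIUNet(),
    loss_fn=CrossEntropyLoss(reduction='none'),  # per-pixel losses; step() masks and averages them
    viz_tfm=toy_viz_tfm,
    total_steps=10_000,                          # step budget for the OneCycleLR schedule
)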