biasutti2019riu

Summarizing quotes from the paper:

### 1. Introduction

  • “we propose RIU-Net (for Range-Image U-Net), the adaptation of U-Net […] to the semantic segmentation of 3D LiDAR point clouds.”

### 2. Related works

  • “Recently, Wu et al. proposed SqueezeSeg, a novel approach for the semantic segmentation of a LiDAR point cloud represented as a spherical range-image.”

### 3. Methodology

  • “The method consists in feeding the U-Net architectures with 2-channels [range-images] encoding range and elevation.”

#### 3.A Input of the Network

  • “we use a range-image named u of 512 × 64px with two channels: the depth towards the sensor and the elevation.” (A rough projection sketch follows this quote list.)

#### 3.B Architecture

Figure from the original paper (Fig.2).

  • “the encoder consists in the repeated application of two 3×3 convolutions followed by a rectified linear unit (ReLU) and a 2×2 max-pooling layer that downsamples the input by a factor 2. Each time a downsampling is done, the number of features is doubled.”
  • “The decoder consists in upsampling blocks where the input is upsampled using 2 × 2 up-convolutions. Then, concatenation is done between the upsampled feature map and the corresponding feature map of the encoder. After that, two 3 × 3 convolutions are applied followed by a ReLU. This block is repeated until the output of the network matches the dimension of the input.”
  • “the last layer consists in a 1x1 convolution that outputs as many features as the wanted number of possible labels K 1-hot encoded.”

#### 3.C Loss function

  • “is defined as the cross-entropy of the softmax of the output of the network.”
  • “[in the equation shown in the paper] m(x) > 0 are the valid pixels and w(x) is a weighting function introduced to give more importance to pixels that are close to a separation between two labels, as defined in U-Net.”

#### 3.D Training

  • “We train the network with the Adam stochastic gradient optimizer and a learning rate set to 0.001. We also use batch normalization with a momentum of 0.99 to ensure good convergence of the model. Finally, the batch size is set to 8 and the training is stopped after 10 epochs.”

### 4. Experiments

  • “we follow the experimental setup of the SqueezeSeg approach […] which contains 8057 samples for training and 2791 for validation.”
  • “we use the intersection-over-union metric”
  • “we advocate that the proposed model can operate with a frame-rate of 90 frames per second on a single GPU”
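
To make 3.A concrete, here is a rough sketch (not taken from the paper or from this notebook) of how such a 2-channel range/elevation image could be built from a raw point cloud. The field-of-view bounds are assumptions loosely based on a Velodyne HDL-64E, and pixel collisions are resolved by simply overwriting:

import numpy as np

def to_range_image(points, h=64, w=512, fov_up_deg=3.0, fov_down_deg=-25.0):
    "Rough sketch: project an (N, 3) point cloud to a (2, h, w) range/elevation image."
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    rng = np.sqrt(x**2 + y**2 + z**2)                 # depth towards the sensor
    yaw = np.arctan2(y, x)                            # azimuth in [-pi, pi]
    pitch = np.arcsin(z / np.maximum(rng, 1e-8))      # elevation angle
    fov_up, fov_down = np.radians(fov_up_deg), np.radians(fov_down_deg)
    u = ((yaw + np.pi) / (2*np.pi) * w).astype(int).clip(0, w - 1)                # column
    v = ((fov_up - pitch) / (fov_up - fov_down) * h).astype(int).clip(0, h - 1)   # row
    img = np.zeros((2, h, w), dtype=np.float32)
    img[0, v, u] = rng   # channel 0: range
    img[1, v, u] = z     # channel 1: elevation
    return img

img = to_range_image(np.random.randn(1000, 3))
assert img.shape == (2, 64, 512)

Note that the implementation below defaults to 5 input channels rather than the 2 described in the paper.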

RIUNet architecture


source

Block

 Block (in_channels, out_channels)

Convolutional block repeatedly used in the RIU-Net encoder and decoder.

Exported source
class Block(Sequential):
    "Convolutional block repeatedly used in the RIU-Net encoder and decoder."
    def __init__(self, in_channels, out_channels):
        super().__init__(OrderedDict([
            (f'conv1', Conv2d(in_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn1', BatchNorm2d(out_channels, momentum=0.01)),
            (f'relu1', ReLU()),
            (f'conv2', Conv2d(out_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn2', BatchNorm2d(out_channels, momentum=0.01)),
            (f'relu2', ReLU()),
        ]))
        self.init_params()
    
    def init_params(self):
        for n, p in self.named_parameters():
            if re.search(r'conv\d\.weight', n):
                kaiming_normal_(p, nonlinearity='relu')

It implements the following architecture:

flowchart LR
  A(("
  Input
  (bs, in_c, h, w)")) --> B["
  Conv(3x3)
  in_c -> out_c"]
  B --> C["BatchNorm2d"]
  C --> D["ReLU"]
  D --> E["
  Conv(3x3)
  out_c -> out_c"]
  E --> F["BatchNorm2d"]
  F --> G["ReLU"]
  G --> H(("
  Output
  (bs, out_c, h, w)"))

Here is an example of how to use it:

bs, in_c, out_c, h, w = 1, 5, 64, 64, 512
inp = torch.randn(bs, in_c, h, w)

b = Block(in_c, out_c)
outp = b(inp)
assert outp.shape == (bs, out_c, h, w)
print(outp.shape, f'== ({bs}, {out_c}, {h}, {w})')
torch.Size([1, 64, 64, 512]) == (1, 64, 64, 512)

It initializes the weights of the conv layers with the kaiming_normal_ algorithm in fan_in mode, as described in U-Net (page 5).

from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
colors_list = list(mcolors.TABLEAU_COLORS)

def plot_param_dists(net, param_re_pattern, nonlinearity_gain):
    color_idx = 0
    for n, p in net.named_parameters():
        if re.search(param_re_pattern, n):
            x_range = 1.1*p.data.max()
            ## kaiming normal dist
            fan_in = p.shape[1]*p.shape[2]*p.shape[3]
            mu, sigma = 0., np.sqrt(nonlinearity_gain/fan_in)
            x = np.linspace(-x_range, x_range, 100)
            y = ((1./(np.sqrt(2*np.pi)*sigma))*np.exp(-0.5*((1./sigma)*(x - mu))**2))
            plt.plot(x, y, '--', color=colors_list[color_idx], label='Expected '+n)
            ## sampled weight dist
            plt.hist(p.view(-1).data, 30, density=True, alpha=0.5, color=colors_list[color_idx], label='Actual '+n)
            color_idx += 1
    plt.legend();
plot_param_dists(b, r'conv\d\.weight', 2.)

for n, p in b.named_parameters():
    if re.search(r'conv\d\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., np.sqrt(2./fan_in)
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-2
        assert abs(sigma - p_data.std()) < 1e-2

source

Encoder

 Encoder (channels=(5, 64, 128, 256, 512, 1024))

RIU-Net encoder architecture.

Exported source
class Encoder(Module):
    "RIU-Net encoder architecture."
    def __init__(self, channels=(5, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.blocks = ModuleList(
            [Block(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
    
    def forward(self, x):
        enc_features = []
        for block in self.blocks:
            x = block(x)
            enc_features.append(x)
            x = F.max_pool2d(x, 2)
        return enc_features

It implements the following architecture:

flowchart LR
  A(("
  Input
  (bs, 5, h, w)")) --> B["
  Block
  5 -> 64"]
  B --> C["MaxPool(2x2)"]
  C --> D["
  Block
  64 -> 128"]
  D --> E["MaxPool(2x2)"]
  E --> F["
  Block
  128 -> 256"]
  F --> G["MaxPool(2x2)"]
  G --> H["
  Block
  256 -> 512"]
  H --> I["MaxPool(2x2)"]
  I --> J["
  Block
  512 -> 1024"]
  B --> L(("
  Output
  [(bs, 64, h, w),
  (bs, 128, h/2, w/2),
  (bs, 256, h/4, w/4),
  (bs, 512, h/8, w/8),
  (bs, 1024, h/16, w/16)]"))
  D --> L
  F --> L
  H --> L
  J --> L

Here is an example of how to use it:

enc = Encoder()
outp = enc(inp)
[o.shape for o in outp]
[torch.Size([1, 64, 64, 512]),
 torch.Size([1, 128, 32, 256]),
 torch.Size([1, 256, 16, 128]),
 torch.Size([1, 512, 8, 64]),
 torch.Size([1, 1024, 4, 32])]

For the decoder, the paper mentions the application of “up-convolutions”, which were first defined in U-Net as:

“an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”)”

Our implementation replaces the 2x2 convolution with a 3x3 one so that the upsampled feature maps keep the same spatial size as the corresponding encoder feature maps, avoiding the cropping U-Net needs at the skip connections.
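
For instance (shapes only; the layers here are throwaway, using the same Conv2d/Upsample classes as the exported code):

up = Upsample(scale_factor=2)
x = torch.randn(1, 1024, 4, 32)
print(Conv2d(1024, 512, 2)(up(x)).shape)        # 2x2, no padding: (1, 512, 7, 63) -> would need cropping
print(Conv2d(1024, 512, 3, 1, 1)(up(x)).shape)  # 3x3, padding 1:  (1, 512, 8, 64) -> matches the encoder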


source

UpConv

 UpConv (in_channels, out_channels)

Up-convolution operation adapted from U-Net.

Exported source
class UpConv(Sequential):
    "Up-convolution operation adapted from [U-Net](https://arxiv.org/abs/1505.04597)."
    def __init__(self, in_channels, out_channels):
        super().__init__(OrderedDict([
            (f'upsample', Upsample(scale_factor=2)),
            (f'conv', Conv2d(in_channels, out_channels, 3, 1, 1, bias=False, padding_mode='circular')),
            (f'bn', BatchNorm2d(out_channels, momentum=0.01)),
        ]))
        self.init_params()
    
    def init_params(self):
        for n, p in self.named_parameters():
            if re.search('conv.weight', n):
                kaiming_normal_(p, nonlinearity='linear')

Here is an example of how to use it:

upc = UpConv(1024, 512)
outp = enc(inp)
print(f'before: {outp[-1].shape}')
outp = upc(outp[-1])
print(f'after: {outp.shape}')
before: torch.Size([1, 1024, 4, 32])
after: torch.Size([1, 512, 8, 64])

source

Decoder

 Decoder (channels=(1024, 512, 256, 128, 64))

RIU-Net decoder architecture.

Exported source
class Decoder(Module):
    "RIU-Net decoder architecture."
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.upconvs = ModuleList(
            [UpConv(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
        self.blocks = ModuleList(
            [Block(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )
    
    def forward(self, enc_features):
        x = enc_features[-1]
        for i, (upconv, block) in enumerate(zip(self.upconvs, self.blocks)):
            x = upconv(x)
            x = torch.cat([x, enc_features[-(i+2)]], dim=1)
            x = block(x)
        return x

It implements the following architecture:

flowchart LR
  A(("
  Input
  [(bs, 1024, h/16, w/16),
  (bs, 512, h/8, w/8),
  (bs, 256, h/4, w/4),
  (bs, 128, h/2, w/2),
  (bs, 64, h, w)]")) --> B["
  UpConv
  1024 -> 512"]
  B --> C["
  Block
  concat(512,512) -> 512"]
  A --> C
  C --> D["
  UpConv
  512 -> 256"]
  D --> E["
  Block
  concat(256,256) -> 256"]
  A --> E
  E --> F["
  UpConv
  256 -> 128"]
  F --> G["
  Block
  concat(128,128) -> 128"]
  A --> G
  G --> H["
  UpConv
  128 -> 64"]
  H --> I["
  Block
  concat(64,64) -> 64"]
  A --> I
  I --> J(("
  Output
  (bs, 64, h, w)"))

Here is an example of how to use it:

dec = Decoder()
outp = enc(inp)
fts = dec(outp)
assert fts.shape == (bs, out_c, h, w)
print(fts.shape, f'== ({bs}, {out_c}, {h}, {w})')
torch.Size([1, 64, 64, 512]) == (1, 64, 64, 512)

It initializes the weights of the upconv layers with the kaiming_normal_ algorithm in fan_in mode and the nonlinearity set to ‘linear’.

plot_param_dists(dec, r'upconvs\.\d\.conv\.weight', 1.)

for n, p in dec.named_parameters():
    if re.search(r'upconvs\.\d\.conv\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., np.sqrt(1./fan_in)
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-3
        assert abs(sigma - p_data.std()) < 1e-3

source

RIUNet

 RIUNet (in_channels=5, hidden_channels=(64, 128, 256, 512, 1024),
         n_classes=20)

RIU-Net complete architecture.

Exported source
class RIUNet(Module):
    "RIU-Net complete architecture."
    def __init__(self, in_channels=5, hidden_channels=(64, 128, 256, 512, 1024), n_classes=20):
        super().__init__()
        self.n_classes = n_classes
        self.input_norm = BatchNorm2d(in_channels, affine=False, momentum=None)
        self.backbone = Sequential(OrderedDict([
            (f'enc', Encoder((in_channels, *hidden_channels))),
            (f'dec', Decoder(hidden_channels[::-1]))
        ]))
        self.head = Conv2d(hidden_channels[0], n_classes, 1)
        self.init_params()

    def init_params(self):
        for n, p in self.named_parameters():
            if re.search(r'head\.weight', n):
                normal_(p, std=1e-2)
            if re.search(r'head\.bias', n):
                zeros_(p)
    
    def forward(self, x):
        x = self.input_norm(x)
        features = self.backbone(x)
        prediction = self.head(features)
        
        return prediction

It implements the following architecture:

flowchart LR
  A(("
  Input
  (bs, 5, h, w)")) --> B["Encoder"]
  B --> C["Decoder"]
  C --> D["
  Conv(1x1)
  64 -> 20"]
  D --> E(("
  Output
  (bs, 20, h, w)"))

Here is an example of how to use it:

n_classes=20
model = RIUNet()
logits = model(inp)
assert logits.shape == (bs, n_classes, h, w)
print(logits.shape, f'== ({bs}, {n_classes}, {h}, {w})')
torch.Size([1, 20, 64, 512]) == (1, 20, 64, 512)

It initializes the weights of the classification head from a normal distribution with a standard deviation of 1e-2.

plot_param_dists(model, r'head\.weight', 0.0064)

for n, p in model.named_parameters():
    if re.search(r'head\.weight', n):
        fan_in = p.shape[1]*p.shape[2]*p.shape[3]
        mu, sigma = 0., 1e-2
        p_data = p.view(-1).data
        assert abs(mu - p_data.mean()) < 1e-2
        assert abs(sigma - p_data.std()) < 1e-2

Loss function
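
The exported loss implementation is not shown in this section. As a reference, here is a minimal sketch of what the paper describes (the cross-entropy of the softmax of the output, restricted to the valid pixels m(x) > 0 and optionally weighted by w(x)); the function name and arguments below are illustrative, not the notebook's API:

import torch
import torch.nn.functional as F

def masked_weighted_ce(logits, target, mask, weight_map=None):
    "Sketch: softmax cross-entropy over valid pixels, optionally weighted near label boundaries."
    ce = F.cross_entropy(logits, target, reduction='none')  # per-pixel loss, shape (N, H, W)
    if weight_map is not None:                               # w(x): emphasis near label separations
        ce = ce * weight_map
    return ce[mask].mean()                                   # m(x) > 0: only valid pixels contribute

logits = torch.randn(1, 20, 64, 512)
target = torch.randint(0, 20, (1, 64, 512))
mask = torch.rand(1, 64, 512) > 0.1
loss = masked_weighted_ce(logits, target, mask)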

LightningModule for standard experiments


source

log_activations

 log_activations (logger, step, model, img)

Function that uses a PyTorch forward hook to log properties of activations for debugging purposes.

Exported source
def log_activations(logger, step, model, img):
    "Function that uses a Pytorch forward hook to log properties of activations for debugging purposes."
    def debugging_hook(module, inp, out):            
        if 'relu' in module.name:
            acts = out.detach()
            
            # per-channel count of (near-)zero activations after the ReLU ("dead" units)
            min_count = (acts < 1e-1).sum((0, 2, 3))
            shape = acts.shape
            total_count = shape[0]*shape[2]*shape[3]
            rate = min_count/total_count
            logger.log({"max_dead_rate/" + str(module.name): rate.max()}, step=step)
            
            #acts_flat = acts.cpu().view(-1,)
            #acts_hist = np.histogram(acts_flat.log1p(), 100)
            #logger.log({"relu_hist/" + str(module.name): wandb.Histogram(np_histogram=acts_hist)}, step=step)
            
    with register_module_forward_hook(debugging_hook):
        model(img)

source

log_imgs

 log_imgs (pred, label, mask, viz_tfm, logger, stage, step)

TODO: documentation missing

Exported source
def log_imgs(pred, label, mask, viz_tfm, logger, stage, step):
    "TODO: documentation missing"
    pred_np = pred[0].detach().cpu().numpy().argmax(0)
    label_np = label[0].detach().cpu().numpy()
    mask_np = mask[0].detach().cpu().numpy()
    pred_np[pred_np == label_np] = 0  # blank out correct predictions so only the errors remain visible
    _, pred_img, _ = viz_tfm(None, pred_np, mask_np)
    _, label_img, _ = viz_tfm(None, label_np, mask_np)
    img_cmp = np.concatenate((pred_img, label_img), axis=0)
    img_cmp = wandb.Image(img_cmp)
    logger.log({f"{stage}_examples": img_cmp}, step=step)

TODO: needs proper documentation with code examples for log_activations and log_imgs functions.
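
In the meantime, here is a minimal sketch of how log_activations can be called. PrintLogger is a hypothetical stand-in for a wandb run (anything exposing log(metrics, step=...)), and module.name has to be set by hand here because the hook keys its metrics by it (SemanticSegmentationTask does this assignment itself). log_imgs additionally needs a wandb run and a viz_tfm, so it is not sketched:

class PrintLogger:
    "Hypothetical stand-in for a wandb run: just prints what would be logged."
    def log(self, metrics, step=None):
        print(step, metrics)

model = RIUNet()
for n, m in model.named_modules():
    m.name = n                        # the hook reads module.name to label each ReLU

log_activations(PrintLogger(), step=0, model=model, img=torch.randn(1, 5, 64, 512))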


source

SemanticSegmentationTask

 SemanticSegmentationTask (model, loss_fn, viz_tfm, total_steps,
                           lr=0.0005)

Lightning Module to standardize experiments with semantic segmentation tasks.

Exported source
class SemanticSegmentationTask(LightningModule):
    "Lightning Module to standardize experiments with semantic segmentation tasks."
    def __init__(self, model, loss_fn, viz_tfm, total_steps, lr=5e-4):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.viz_tfm = viz_tfm
        self.lr = lr
        self.total_steps = total_steps
        self.train_accuracy = Accuracy(task="multiclass", num_classes=model.n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=model.n_classes)
        self.automatic_optimization = False

        self.step_idx = 0
        
        for n, m in self.model.named_modules():
            assert not hasattr(m, 'name')
            m.name = n
        
    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.lr, eps=1e-5)
        lr_scheduler = OneCycleLR(optimizer, self.lr, self.total_steps)
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
    
    def training_step(self, batch, batch_idx):
        stage = 'train'
        logger = self.logger.experiment
        
        loss, pred, label, mask = self.step(batch, batch_idx, stage, self.train_accuracy)
        if self.step_idx % int(0.01*self.total_steps) == 0:
            log_activations(logger, self.step_idx, self.model, batch[0])
        if self.step_idx % int(0.25*self.total_steps) == 0:
            log_imgs(pred, label, mask, self.viz_tfm, logger, stage, self.step_idx)
        self.manual_optimization(loss)
        self.step_idx += 1
    
    def on_train_epoch_end(self):
        self.log('train_acc_epoch', self.train_accuracy)

    def validation_step(self, batch, batch_idx):
        stage = 'val'
        logger = self.logger.experiment
        
        _, pred, label, mask = self.step(batch, batch_idx, stage, self.val_accuracy)
        if batch_idx == 0:
            log_imgs(pred, label, mask, self.viz_tfm, logger, stage, self.step_idx)
    
    def step(self, batch, batch_idx, stage, metric):
        img, label, mask = batch
        label[~mask] = 0  # dummy label for invalid pixels; they are excluded from the loss and metric below
        
        pred = self.model(img)
        
        loss = self.loss_fn(pred, label)
        loss = loss[mask]
        loss = loss.mean()

        pred_f = torch.permute(pred, (0, 2, 3, 1)) # N,C,H,W -> N,H,W,C
        pred_f = torch.flatten(pred_f, 0, -2)      # N,H,W,C -> N*H*W,C
        mask_f = torch.flatten(mask)               # N,H,W   -> N*H*W
        pred_m = pred_f[mask_f, :]
        label_m = label[mask]
        metric(pred_m, label_m)
        
        self.log(f"{stage}_acc_step", metric)
        self.log(f"{stage}_loss_step", loss.log10())

        return loss, pred, label, mask
    
    def manual_optimization(self, loss):
        optimizer = self.optimizers()
        optimizer.zero_grad()
        self.manual_backward(loss)
        
        p_old = {}
        for n, p in self.model.named_parameters():
            p_old[n] = p.detach().clone()
        
        optimizer.step()
        
        for n, p in self.model.named_parameters():
            optim_step = p.detach() - p_old[n]
            
            #log_rate = optim_step.abs().log1p() - p_old[n].abs().log1p()
            #log_rate_hist = np.histogram(log_rate.cpu().view(-1), 100)
            #self.logger.experiment.log({"log_update_rate_hist/" + str(n): wandb.Histogram(np_histogram=log_rate_hist)}, step=self.step_idx)
            
            self.logger.experiment.log({"ud/" + str(n): ((optim_step.std())/(p_old[n].std() + 1e-5)).log10()}, step=self.step_idx)
        
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
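
Here is a usage sketch of SemanticSegmentationTask. All names except RIUNet and SemanticSegmentationTask are assumptions: the viz_tfm below is a crude placeholder, the step counts are arbitrary, and the loss must be created with reduction='none' because step applies the validity mask to the per-pixel losses before averaging:

from torch.nn import CrossEntropyLoss

def viz_tfm(img, label, mask):
    "Placeholder viz transform: log_imgs expects (img, label_image, mask) back."
    return img, label.astype('float32') / 20., mask   # crude grayscale rendering of class ids

model = RIUNet()
loss_fn = CrossEntropyLoss(reduction='none')   # per-pixel losses; `step` masks and averages them itself
n_epochs, steps_per_epoch = 10, 1000           # assumed; total_steps drives the OneCycleLR schedule
task = SemanticSegmentationTask(model, loss_fn, viz_tfm, total_steps=n_epochs*steps_per_epoch)

# the task logs via self.logger.experiment, so a wandb logger is needed when actually training:
# from pytorch_lightning.loggers import WandbLogger
# trainer = pl.Trainer(max_epochs=n_epochs, logger=WandbLogger(project='riunet'))
# trainer.fit(task, train_dataloader, val_dataloader)  # dataloaders assumed to yield (img, label, mask) batches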