Lei Luo Machine Learning Engineer

Implementing YOLO-V3 Using PyTorch



Introduction

The You Only Look Once (YOLO) object detection system was developed by Joseph Redmon, Santosh Divvala, Ross Girshick and Ali Farhadi. Unlike many other object detection systems, such as R-CNN, YOLO frames object detection as a regression problem to spatially separated bounding boxes and associated class probabilities. R-CNN performs detection in three stages, which makes it hard to optimize. It first uses a region proposal method to generate potential bounding boxes in an image, then runs a classifier on these proposed boxes. After classification, post-processing is used to refine the bounding boxes, eliminate duplicate detections, and rescore the boxes based on other objects in the scene. YOLO, on the other hand, uses a single neural network that predicts bounding boxes and class probabilities directly from full images in one evaluation, so it can be optimized end-to-end directly on detection performance.

The key features of YOLO are:

First, YOLO is very fast. Since detection is framed as a regression problem, YOLO does not need a complex pipeline: a single network runs on the whole image. The base network runs at 45 frames per second with no batch processing on a Titan X GPU.

Second, YOLO reasons globally about the image when making predictions. Unlike sliding window and region proposal-based techniques, YOLO sees the entire image during training and test time, so it implicitly encodes contextual information about classes as well as their appearance. As a result, YOLO makes less than half the number of background errors compared to Fast R-CNN.

How Bounding Boxes Work

YOLO uses features from the entire image to predict bounding boxes across all classes for an image simultaneously. It divides the input image into an $S\times S$ grid. If the center of an object falls into a grid cell, that grid cell is responsible for detecting that object. Each grid cell predicts $B$ bounding boxes and confidence scores for those boxes. These confidence scores reflect how confident the model is that the box contains an object and how accurate it thinks the predicted box is.


Confidence is formally defined as $\Pr(\text{Object}) \cdot \text{IOU}^{\text{truth}}_{\text{pred}}$. If no object exists in that cell, the confidence score should be zero. Otherwise we want the confidence score to equal the intersection over union (IOU) between the predicted box and the ground truth. Each bounding box consists of 5 predictions: $x, y, w, h$, and confidence. The $(x, y)$ coordinates represent the center of the box relative to the bounds of the grid cell. The width and height are predicted relative to the whole image. Each grid cell also predicts $C$ conditional class probabilities, $\Pr(\text{Class}_i \mid \text{Object})$. These probabilities are conditioned on the grid cell containing an object. We only predict one set of class probabilities per grid cell, regardless of the number of boxes $B$. At test time, we multiply the conditional class probabilities by each box's confidence, so for each bounding box we have:

$$\Pr(\text{Class}_i \mid \text{Object}) \cdot \Pr(\text{Object}) \cdot \text{IOU}^{\text{truth}}_{\text{pred}} = \Pr(\text{Class}_i) \cdot \text{IOU}^{\text{truth}}_{\text{pred}}$$

which gives us class-specific confidence scores for each box. These scores encode both the probability of that class appearing in the box and how well the predicted box fits the object.
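To make these two quantities concrete, here is a minimal sketch in plain Python, assuming boxes are given as (x_center, y_center, width, height) in the same units. The function names iou and class_specific_scores are my own, for illustration only; they are not taken from the repository.

```python
def iou(box_a, box_b):
    """IOU of two boxes given as (x_center, y_center, width, height)."""
    # Convert centers and sizes to corner coordinates
    ax1, ay1 = box_a[0] - box_a[2] / 2, box_a[1] - box_a[3] / 2
    ax2, ay2 = box_a[0] + box_a[2] / 2, box_a[1] + box_a[3] / 2
    bx1, by1 = box_b[0] - box_b[2] / 2, box_b[1] - box_b[3] / 2
    bx2, by2 = box_b[0] + box_b[2] / 2, box_b[1] + box_b[3] / 2
    # The overlap is zero when the boxes do not intersect
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0 else 0.0


def class_specific_scores(class_probs, box_confidence):
    """Multiply each Pr(Class_i | Object) by the box confidence Pr(Object) * IOU."""
    return [p * box_confidence for p in class_probs]


ground_truth = (0.5, 0.5, 0.4, 0.4)
predicted = (0.55, 0.5, 0.4, 0.4)
confidence = iou(ground_truth, predicted)  # the ideal confidence target for this box
print(class_specific_scores([0.9, 0.1], confidence))
```

Shifting the predicted box slightly to the right lowers its IOU with the ground truth, and that lower confidence scales every class score for the box.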

How Anchor Boxes Work

The original YOLO predicts the coordinates of bounding boxes directly using fully connected layers on top of the convolutional feature extractor. As an improvement, YOLO V2 borrows an idea from Faster R-CNN, which predicts bounding box offsets relative to hand-picked priors (anchor boxes) instead of predicting coordinates directly. Predicting offsets instead of coordinates simplifies the problem and makes it easier for the network to learn.

The next question is what dimensions to use for the priors. Instead of choosing priors by hand, we can run k-means clustering on the training set bounding boxes to find good priors automatically. The distance metric used is $d(\text{box}, \text{centroid}) = 1 - \text{IOU}(\text{box}, \text{centroid})$ rather than Euclidean distance: with Euclidean distance, larger boxes generate more error than smaller ones, whereas what we want are priors that lead to good IOU scores regardless of box size. YOLO V2 uses 5 clusters. YOLO V3 uses 9 clusters, split across 3 different scales.
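The clustering step can be sketched with NumPy as follows. This is an illustration under a few assumptions: each box is reduced to its (width, height) pair and compared as if all boxes shared the same center, centroids are initialized from randomly chosen boxes, and each centroid is updated to the mean of its members (some implementations use the median instead). The names wh_iou and kmeans_anchors are hypothetical.

```python
import numpy as np


def wh_iou(boxes, centroids):
    """IOU between (w, h) pairs, treating every box as sharing the same center.

    boxes: (N, 2), centroids: (K, 2) -> returns an (N, K) matrix
    """
    inter = (np.minimum(boxes[:, None, 0], centroids[None, :, 0])
             * np.minimum(boxes[:, None, 1], centroids[None, :, 1]))
    area_b = boxes[:, 0] * boxes[:, 1]
    area_c = centroids[:, 0] * centroids[:, 1]
    return inter / (area_b[:, None] + area_c[None, :] - inter)


def kmeans_anchors(boxes, k, iters=100, seed=0):
    """Cluster box sizes using d(box, centroid) = 1 - IOU(box, centroid)."""
    rng = np.random.default_rng(seed)
    centroids = boxes[rng.choice(len(boxes), k, replace=False)]
    for _ in range(iters):
        # Assign each box to the centroid with the smallest 1 - IOU distance
        assign = np.argmin(1.0 - wh_iou(boxes, centroids), axis=1)
        # Move each centroid to the mean of its members (keep it if it has none)
        new = np.array([
            boxes[assign == j].mean(axis=0) if np.any(assign == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new, centroids):
            break
        centroids = new
    # Sort by area so that small anchors come first, as in the YOLO config files
    return centroids[np.argsort(centroids[:, 0] * centroids[:, 1])]


boxes = np.array([[10, 10], [11, 11], [9, 10], [100, 100], [105, 98], [98, 102]], dtype=float)
print(kmeans_anchors(boxes, 2))
```

On real data, calling kmeans_anchors on the (width, height) pairs of all training labels with k=9 would give nine priors to distribute across the three YOLO V3 scales.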

The next problem the authors encountered is model instability, especially during early iterations. With the region proposal network style of offsets, the predictions are unconstrained, so any anchor box can end up at any point in the image regardless of which location predicted it. With random initialization, the model then takes a long time to stabilize to predicting sensible offsets. Instead, the authors predict location coordinates relative to the location of the grid cell. This bounds the ground truth to fall between 0 and 1, and a logistic activation constrains the network's predictions to fall in this range.
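In the YOLO V2 paper this parameterization is $b_x = \sigma(t_x) + c_x$, $b_y = \sigma(t_y) + c_y$, $b_w = p_w e^{t_w}$ and $b_h = p_h e^{t_h}$, where $(c_x, c_y)$ is the offset of the grid cell from the top-left corner of the image and $(p_w, p_h)$ is the size of the prior. A minimal decoding sketch, assuming all values are in grid units (decode_box and sigmoid are hypothetical helper names):

```python
import math


def sigmoid(v):
    """Numerically stable logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def decode_box(t, cell, prior):
    """Turn raw outputs (tx, ty, tw, th) into a box (bx, by, bw, bh) in grid units.

    cell  = (cx, cy): offset of the grid cell from the top-left corner
    prior = (pw, ph): anchor width and height in grid units
    """
    tx, ty, tw, th = t
    cx, cy = cell
    pw, ph = prior
    bx = sigmoid(tx) + cx  # the logistic keeps the center inside cell (cx, cy)
    by = sigmoid(ty) + cy
    bw = pw * math.exp(tw)  # width and height rescale the anchor prior
    bh = ph * math.exp(th)
    return bx, by, bw, bh


print(decode_box((10.0, -10.0, 0.0, 0.0), (3, 5), (1.0, 1.0)))
```

Because $\sigma(t_x)$ lies between 0 and 1, even an extreme prediction such as $t_x = 10$ keeps the center inside its own cell, which is exactly the constraint described above.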

Network Design

The original YOLO has 24 convolutional layers followed by 2 fully connected layers. YOLO V2 uses Darknet-19 as its backbone, which has 19 convolutional layers and 5 max-pooling layers. YOLO V3 uses Darknet-53, a backbone with 53 convolutional layers. Their respective structures are as follows:

Loss Function

The loss function has multiple parts:
1) Bounding box coordinate and dimension error, measured with squared error.
2) Objectness error, i.e. the error in the confidence score of whether there is an object or not. When there is an object, we want the score to equal the IOU, and when there is no object we want the score to be zero. This is also measured with squared error.
3) Classification error, which this implementation computes with cross-entropy loss (the original YOLO paper uses squared error for this term as well).

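If the loss function weights localization error equally with classification error, it does not perfectly align with the goal of maximizing average precision. Moreover, in every image many grid cells do not contain any object. This pushes the confidence scores of those cells towards zero, often overpowering the gradient from cells that do contain objects, which can make the model unstable and cause training to diverge early on. To remedy this, we increase the loss from bounding box coordinate predictions and decrease the loss from confidence predictions for boxes that don't contain objects, using two parameters: $\lambda_{coord}=5$ and $\lambda_{noobj}=0.5$. A plain squared error also weights errors in large boxes and small boxes equally, whereas the error metric should reflect that small deviations in large boxes matter less than in small boxes. To partially address this, the authors predict the square root of the bounding box width and height instead of the width and height directly. Putting it all together, the loss function is formally proposed as:

$$
\begin{aligned}
&\lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left[(x_i-\hat{x}_i)^2+(y_i-\hat{y}_i)^2\right] \\
&+\lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left[\left(\sqrt{w_i}-\sqrt{\hat{w}_i}\right)^2+\left(\sqrt{h_i}-\sqrt{\hat{h}_i}\right)^2\right] \\
&+\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left(C_i-\hat{C}_i\right)^2+\lambda_{noobj}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{noobj}\left(C_i-\hat{C}_i\right)^2 \\
&+\sum_{i=0}^{S^2}\mathbb{1}_{i}^{obj}\sum_{c \in \text{classes}}\left(p_i(c)-\hat{p}_i(c)\right)^2
\end{aligned}
$$

where $\mathbb{1}_{i}^{obj}$ indicates that an object appears in cell $i$, and $\mathbb{1}_{ij}^{obj}$ indicates that the $j$-th bounding box predictor in cell $i$ is responsible for that prediction.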
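To see why the square root helps, compare the same absolute width error on a small box and on a large box. The helper names below are illustrative only:

```python
import math


def plain_error(w_true, w_pred):
    """Squared error on the raw width."""
    return (w_true - w_pred) ** 2


def sqrt_error(w_true, w_pred):
    """Squared error on the square root of the width, as in the original YOLO loss."""
    return (math.sqrt(w_true) - math.sqrt(w_pred)) ** 2


# The same 5-pixel width error on a 10-pixel box and on a 100-pixel box
for w_true in (10.0, 100.0):
    w_pred = w_true + 5.0
    print(w_true, plain_error(w_true, w_pred), round(sqrt_error(w_true, w_pred), 3))
```

The plain squared error is 25 in both cases, while the square-root version penalizes the error on the 10-pixel box roughly eight times more than on the 100-pixel box.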

Implement Building Blocks

In this section, I will talk about the key implementation points of YOLO V3.

The first module to look at is the structure of the neural network, which consists of several blocks. The model config file lists the blocks that build up the network, and we will look into them one by one.
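Before building layers, the config file has to be read into a list of blocks. A minimal parser might look like the following. It assumes the Darknet .cfg layout shown in the examples below: a [type] header starts each block, key=value lines follow, and lines starting with # are comments. parse_cfg is a hypothetical name, and all values are kept as strings for the layer builders to convert.

```python
def parse_cfg(text):
    """Parse Darknet-style cfg text into a list of dicts, one per [block]."""
    blocks = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        if line.startswith("["):
            # A header such as [convolutional] opens a new block
            blocks.append({"type": line[1:-1].strip()})
        else:
            # Everything else is a key=value option of the current block
            key, value = line.split("=", 1)
            blocks[-1][key.strip()] = value.strip()
    return blocks


example = """
[convolutional]
batch_normalize=1
filters=32

[route]
layers = -1, 61
"""
print(parse_cfg(example))
```

Keeping values as strings leaves each layer type free to interpret its own options, for example splitting "-1, 61" into two indices for a route layer.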

  1. Convolutional layer
    In the model config file we can see something like this:
    [convolutional]
    batch_normalize=1
    filters=32
    size=3
    stride=1
    pad=1
    activation=leaky
    

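    It means we will build a 2D convolutional layer with 32 filters, a 3x3 kernel, a stride of 1 in both dimensions and padding enabled (in Darknet configs, pad=1 means a padding of (size - 1) // 2, which is 1 for a 3x3 kernel), followed by a batch normalization layer and a leaky ReLU activation. Here batch_normalize=1 is a flag that turns batch normalization on; when it is on, the convolution needs no bias term.
    When implementing, it can be expressed as:

    # `block` is the parsed [convolutional] section (string values), `module` an
    # nn.Sequential and `prev_filters` the depth of the previous layer's output
    batch_normalize = int(block.get("batch_normalize", 0))
    filters = int(block["filters"])
    kernel_size = int(block["size"])
    stride = int(block["stride"])
    pad = (kernel_size - 1) // 2 if int(block["pad"]) else 0
    bias = not batch_normalize  # batch norm's own shift makes the conv bias redundant
    conv = nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=bias)
    module.add_module("conv_{0}".format(index), conv)
    if batch_normalize:
        bn = nn.BatchNorm2d(filters)
        module.add_module("batch_norm_{0}".format(index), bn)
    if block["activation"] == "leaky":
        activn = nn.LeakyReLU(0.1, inplace=True)
        module.add_module("leaky_{0}".format(index), activn)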
    
  2. Shortcut layer
    The output of this layer is obtained by adding the feature maps of the previous layer and of the layer three positions before the shortcut layer (from=-3).
    [shortcut]
    from=-3
    activation=linear
    

    When implementing, it can be expressed as:

    elif block["type"] == "shortcut":
      layer_i = int(block["from"])
      x = layer_outputs[-1] + layer_outputs[layer_i]
    
  3. Route layer
    There are two types of route layers.
    [route]
    layers = -4
    [route]
    layers = -1, 61
    

    When layers has a single value, the route layer outputs the feature maps of the layer indexed by that value. When it has multiple values, it outputs the feature maps of those layers concatenated along the channel dimension. Negative values count back from the route layer, while positive values are absolute layer indices.
    This is similar in spirit to the passthrough layer described in the YOLO V2 paper, which reads: “The passthrough layer concatenates the higher resolution features with the low resolution features by stacking adjacent features into different channels instead of spatial locations, similar to the identity mappings in ResNet.”
    When implementing, it can be expressed as:

    elif block["type"] == "route":
      # The value is a string such as "-1, 61", so split it before converting
      layer_i = [int(v) for v in block["layers"].split(",")]
      x = torch.cat([layer_outputs[i] for i in layer_i], 1)
    
  4. Upsample layer
    It upsamples the output of the previous layer by a factor equal to stride.
    [upsample]
    stride=2
    

    When implementing, it can be expressed as:

    elif x["type"] == "upsample":
      stride = int(x["stride"])
      upsample = nn.Upsample(scale_factor=stride, mode="nearest")
      module.add_module("upsample_{}".format(index), upsample)
    
  5. YOLO layer
    This type of layer is for detecting objects. In YOLO V3 there are three of these layers, each responsible for detecting objects at one scale, and at each scale 3 anchor boxes are assigned to each grid cell. The mask selects which of the listed anchors this layer uses: in the example below the mask is 6,7,8, meaning that this layer uses the last three (largest) anchor boxes, while a mask of 0,1,2 would select the first three.
    In the model config file, it looks like this:
    [yolo]
    mask = 6,7,8
    anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
    classes=80
    num=9
    jitter=.3
    ignore_thresh = .7
    truth_thresh = 1
    random=1

    When implementing, it can be expressed as:

    elif x["type"] == "yolo":
      mask = [int(m) for m in x["mask"].split(",")]
      anchors = [int(a) for a in x["anchors"].split(",")]
      # Group the flat list into (width, height) pairs
      anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
      # Keep only the anchors selected by this layer's mask
      anchors = [anchors[i] for i in mask]
      detection = YOLOLayer(anchors)
      module.add_module("Detection_{}".format(index), detection)
    
  6. Empty layer
    There is a trick when working with route and shortcut layers: the use of an empty layer, which acts as a dummy layer and is simply defined as:
    class EmptyLayer(nn.Module):
        def __init__(self):
            super(EmptyLayer, self).__init__()


    This is needed because every block in the config file must have a corresponding module (a subclass of nn.Module) in the model's module list, so that layer indices line up with block indices. The actual addition or concatenation is performed later in the forward pass using the stored layer outputs, so the placeholder itself does nothing. For example:

    elif x["type"] == "shortcut":
      from_ = int(x["from"])
      shortcut = EmptyLayer()
      module.add_module("shortcut_{}".format(index), shortcut)
    

    Now that all the building blocks are defined, we can wire them together. Refer to createModules and forward for the complete code.
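The wiring can be sketched as follows. This is a simplified, self-contained version written for illustration, not the repository's createModules and forward: blocks are dicts of strings as a cfg parser would produce, YOLO layers are left out, and the helper conv_block exists only to build a tiny test network.

```python
import torch
import torch.nn as nn


class EmptyLayer(nn.Module):
    """Placeholder so every config block owns a module at the same index."""

    def __init__(self):
        super().__init__()


def create_modules(blocks, in_channels=3):
    """Build one nn.Sequential per block and track each block's output depth."""
    module_list, out_channels = nn.ModuleList(), []
    prev = in_channels
    for index, block in enumerate(blocks):
        module = nn.Sequential()
        if block["type"] == "convolutional":
            bn = int(block.get("batch_normalize", 0))
            filters, size = int(block["filters"]), int(block["size"])
            pad = (size - 1) // 2 if int(block.get("pad", 0)) else 0
            module.add_module(f"conv_{index}", nn.Conv2d(
                prev, filters, size, int(block["stride"]), pad, bias=not bn))
            if bn:
                module.add_module(f"batch_norm_{index}", nn.BatchNorm2d(filters))
            if block.get("activation") == "leaky":
                module.add_module(f"leaky_{index}", nn.LeakyReLU(0.1, inplace=True))
            prev = filters
        elif block["type"] == "upsample":
            module.add_module(f"upsample_{index}", nn.Upsample(
                scale_factor=int(block["stride"]), mode="nearest"))
        elif block["type"] == "route":
            # Concatenation stacks channels, so the depths of the routed layers add up
            layers = [int(v) for v in block["layers"].split(",")]
            prev = sum(out_channels[i] for i in layers)
            module.add_module(f"route_{index}", EmptyLayer())
        elif block["type"] == "shortcut":
            # Element-wise addition keeps the depth unchanged
            module.add_module(f"shortcut_{index}", EmptyLayer())
        module_list.append(module)
        out_channels.append(prev)
    return module_list, out_channels


def forward(blocks, module_list, x):
    """Run the blocks in order, caching every output for route/shortcut lookups."""
    outputs = []
    for block, module in zip(blocks, module_list):
        if block["type"] in ("convolutional", "upsample"):
            x = module(x)
        elif block["type"] == "route":
            # Negative indices count back from this block; positive ones are absolute
            layers = [int(v) for v in block["layers"].split(",")]
            x = torch.cat([outputs[i] for i in layers], dim=1)
        elif block["type"] == "shortcut":
            x = outputs[-1] + outputs[int(block["from"])]
        outputs.append(x)
    return x


def conv_block(filters, size, stride):
    return {"type": "convolutional", "batch_normalize": "1", "filters": str(filters),
            "size": str(size), "stride": str(stride), "pad": "1", "activation": "leaky"}


blocks = [
    conv_block(8, 3, 1),                    # 0
    conv_block(4, 1, 1),                    # 1: pad=1 still means no padding for 1x1
    conv_block(8, 3, 1),                    # 2
    {"type": "shortcut", "from": "-3"},     # 3: outputs of block 2 + block 0
    conv_block(16, 3, 2),                   # 4: halves the spatial size
    {"type": "upsample", "stride": "2"},    # 5: restores it
    {"type": "route", "layers": "-1, 3"},   # 6: concatenates blocks 5 and 3
]
modules, depths = create_modules(blocks)
out = forward(blocks, modules, torch.randn(2, 3, 16, 16))
print(tuple(out.shape), depths)  # prints (2, 24, 16, 16) [8, 4, 8, 8, 16, 16, 24]
```

Tracking the output depth of every block at build time is what lets a route layer that follows concatenated features know how many input channels the next convolution needs.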

Implement Losses

The key to successfully training the network is calculating the losses, which I will discuss in this section. Each bounding box in the training data is represented as:
class label, x, y, width, height
where x, y, width and height are normalized by the image dimensions.
But, as discussed earlier, the network predicts box locations relative to the grid cell rather than to the whole image. So we first scale the normalized target to grid units: the integer part of the scaled center gives the grid cell (gi, gj) responsible for the object, and the fractional part becomes the regression target. In the code, it is reflected like:

gx = targets[batch_idx,target_idx, 1] * grid_size
gy = targets[batch_idx,target_idx, 2] * grid_size
gw = targets[batch_idx,target_idx, 3] * grid_size
gh = targets[batch_idx,target_idx, 4] * grid_size
gi = int(gx)
gj = int(gy)

As mentioned, at each scale we define three anchor boxes for each grid cell, and we need to pick the anchor box that has the highest IOU with the target box.

anchor_iou = bboxIOU(gt_box, anchor_shapes, True)
best = np.argmax(anchor_iou)
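A sketch of that selection, assuming the anchors have already been divided by the layer's stride so that they are in grid units (for example, stride 32 for the 13x13 scale of a 416x416 input). It compares only shapes, treating both boxes as if they shared the same center; best_anchor is a hypothetical helper, not the repository's bboxIOU.

```python
import numpy as np


def best_anchor(gw, gh, anchors):
    """Index of the anchor whose shape best matches a target of size (gw, gh).

    Boxes are compared as if they shared the same center, so only
    width and height matter.
    """
    anchors = np.asarray(anchors, dtype=float)
    inter = np.minimum(gw, anchors[:, 0]) * np.minimum(gh, anchors[:, 1])
    union = gw * gh + anchors[:, 0] * anchors[:, 1] - inter
    return int(np.argmax(inter / union))


# The largest three anchors from the config above, divided by a stride of 32
anchors = [(116 / 32, 90 / 32), (156 / 32, 198 / 32), (373 / 32, 326 / 32)]
print(best_anchor(4.8, 6.0, anchors))
```

A target of 4.8 x 6.0 grid cells is closest in shape to the second anchor (about 4.9 x 6.2), so that anchor becomes responsible for it.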

After determining the best anchor box, we can compute the regression targets relative to the responsible grid cell and the chosen anchor, rather than relative to the entire input image, whose dimensions may vary.

tx[batch_idx, best, gj, gi] = gx - gi
ty[batch_idx, best, gj, gi] = gy - gj
tw[batch_idx, best, gj, gi] = math.log(gw / anchors[best][0] + 1e-16)
th[batch_idx, best, gj, gi] = math.log(gh / anchors[best][1] + 1e-16)

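$t_w$ and $t_h$ are expressed this way because, in the YOLO V2 paper, the predicted width and height are obtained from the anchor (prior) dimensions $p_w$ and $p_h$ as:
$b_{w} = p_{w}e^{t_{w}}$
$b_{h} = p_{h}e^{t_{h}}$
so the targets are $t_w = \log(b_w / p_w)$ and $t_h = \log(b_h / p_h)$; the small constant 1e-16 guards against taking the log of zero. We also need to set the target confidence and class label, which are easy to determine. For more information, see buildTargets in utils.py.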
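The encoding for a single box can be sketched as follows, together with a decoder that checks the encoding is invertible. The names encode_target and decode_target are hypothetical; they mirror the snippet above for one box and assume the anchor is given in grid units. Unlike the snippet above, the cell index is clamped so that a center lying exactly on the right or bottom edge (x = 1.0) still maps to a valid cell.

```python
import math


def encode_target(box, grid_size, anchor):
    """Encode one normalized target box into (gi, gj, tx, ty, tw, th).

    box:    (x, y, w, h), normalized to [0, 1] by the image size
    anchor: (pw, ph) in grid units
    """
    gx, gy = box[0] * grid_size, box[1] * grid_size
    gw, gh = box[2] * grid_size, box[3] * grid_size
    # The cell containing the box center (clamped for centers on the far edge)
    gi, gj = min(int(gx), grid_size - 1), min(int(gy), grid_size - 1)
    tx, ty = gx - gi, gy - gj  # targets for sigmoid(t_x) and sigmoid(t_y)
    tw = math.log(gw / anchor[0] + 1e-16)  # inverse of b_w = p_w * exp(t_w)
    th = math.log(gh / anchor[1] + 1e-16)
    return gi, gj, tx, ty, tw, th


def decode_target(gi, gj, tx, ty, tw, th, grid_size, anchor):
    """Invert encode_target back to a normalized (x, y, w, h) box."""
    return ((gi + tx) / grid_size, (gj + ty) / grid_size,
            anchor[0] * math.exp(tw) / grid_size,
            anchor[1] * math.exp(th) / grid_size)


box = (0.6012591963974298, 0.08228222014457276, 0.02421177972070192, 0.07377026633651351)
anchor = (116 / 32, 90 / 32)
encoded = encode_target(box, 13, anchor)
print(encoded[:2], decode_target(*encoded, 13, anchor))
```

On a 13x13 grid this box falls in cell (7, 1), and decoding the encoded values recovers the original normalized box up to floating-point error.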

Now that the predicted values and target values have the same format, we can calculate the loss. As discussed earlier, it has three parts: squared error for the bounding box coordinates and dimensions (scaled by $\lambda_{coord}$), squared error for the objectness score (with $\lambda_{noobj}$ down-weighting boxes that contain no object), and cross-entropy for the class labels.

In the code, the loss calculation is written as:

loss_x = self.lambda_coord * self.mse_loss(x[mask], tx[mask])
loss_y = self.lambda_coord * self.mse_loss(y[mask], ty[mask])
loss_w = self.lambda_coord * self.mse_loss(w[mask], tw[mask])
loss_h = self.lambda_coord * self.mse_loss(h[mask], th[mask])

loss_conf = self.lambda_noobj * self.mse_loss(pred_conf[conf_mask_false], tconf[conf_mask_false]) + self.mse_loss(pred_conf[conf_mask_true], tconf[conf_mask_true])
loss_cls = (1 / batch_size) * self.ce_loss(pred_cls[mask], torch.argmax(tcls[mask], 1))
loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

Detect Objects

After the model has been trained for enough epochs, we can use it to detect objects.
One step needs special attention: non-maximum suppression (NMS). The model often predicts several overlapping bounding boxes for the same object, so we need to eliminate duplicate detections. NMS keeps the highest-scoring box, discards the remaining boxes that overlap it by more than an IOU threshold, and repeats with the boxes that are left, so that only one box is output per object. These boxes can then be drawn on the image. See nonMaxSuppression for more details.
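A minimal greedy NMS sketch in PyTorch, assuming boxes in (x1, y1, x2, y2) corner format with one score per box. The threshold of 0.4 is only an illustrative default, and YOLO implementations commonly run this separately for each class. torchvision also ships an optimized torchvision.ops.nms(boxes, scores, iou_threshold) for the same task.

```python
import torch


def box_iou(box, boxes):
    """IOU of one (x1, y1, x2, y2) box against an (N, 4) tensor of boxes."""
    x1 = torch.maximum(box[0], boxes[:, 0])
    y1 = torch.maximum(box[1], boxes[:, 1])
    x2 = torch.minimum(box[2], boxes[:, 2])
    y2 = torch.minimum(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def non_max_suppression(boxes, scores, iou_threshold=0.4):
    """Greedy NMS: return the indices of the kept boxes, highest score first."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        if rest.numel() == 0:
            break
        # Drop every remaining box that overlaps the kept one too much
        overlaps = box_iou(boxes[best], boxes[rest])
        order = rest[overlaps <= iou_threshold]
    return keep


boxes = torch.tensor([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60], [0, 0, 10, 10.5]])
scores = torch.tensor([0.9, 0.8, 0.7, 0.95])
print(non_max_suppression(boxes, scores))
```

Here the three overlapping boxes around the same object collapse to the single highest-scoring one (index 3), while the distant box (index 2) survives.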

Detect Handsup

For experimentation, I trained a YOLO V3 model to detect people raising their hands in a classroom or conference setting and to automatically count the raised hands ("handsup"). The process is as follows:

  1. Crawl images online
    I used a tool called google-image-download that can automatically download images from Google given a search keyword. In my case, I used “students raise hands in classroom”. I ended up downloading about 500 images but only used about 300 of them, as some were cartoons and some had fewer than 3 channels.

  2. Annotate images
    I used sloth to annotate all the images.
    The x, y, width, height attributes have to be relative to the dimensions of the image, so I wrote a script to convert absolute values to relative values. One example is shown below. The first column is the class label. Since I was only interested in ‘handsup’, I only tagged one class label. The remaining columns are x, y, width, and height.
    0 0.6012591963974298 0.08228222014457276 0.02421177972070192 0.07377026633651351 
    0 0.7071857326755004 0.14612187370501714 0.034300021270994296 0.0666769714964641 
    0 0.8645623008600626 0.18016968893725413 0.03934414204614048 0.07235160736850361 
    0 0.8201740380387758 0.23549738868963926 0.03732649373608208 0.07235160736850363 
    0 0.3581325750353818 0.17591371203322453 0.02622942803076032 0.06951428943248386 
    0 0.23404720396678474 0.11065539950477025 0.0322823729609358 0.06242099459243451 
    0 0.0968471188828075 0.31494229089819237 0.027238252185789614 0.06809563046447402 
    0 0.16342951311473766 0.3262915626422714 0.04136179035619904 0.0950501508566616 
    0 0.28650606002830553 0.23975336559366894 0.035308845426023555 0.10923674053676043
    
  3. Calculate anchor box priors
    As discussed earlier, we can use k-means clustering to obtain anchor priors; I used this code for that.

  4. Train and detect
    All the hyperparameters can be tuned. After training the model for 10000 epochs, I got a model that detects raised hands with reasonably good results. The detection draws bounding boxes on objects and counts the total number of objects of interest.

  5. Use webcam
    The model can also run on a webcam feed to detect objects in real time.
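The conversion script from step 2 is not shown in this post; a minimal sketch might look like the following. It assumes the annotation tool exports each box as a pixel-space top-left corner plus width and height, and it writes the Darknet label convention, in which x and y are the box center normalized by the image size. to_yolo_line is a hypothetical name.

```python
def to_yolo_line(class_id, x_min, y_min, box_w, box_h, img_w, img_h):
    """Convert a pixel-space top-left box into a normalized center-format label line."""
    x_center = (x_min + box_w / 2) / img_w
    y_center = (y_min + box_h / 2) / img_h
    return f"{class_id} {x_center} {y_center} {box_w / img_w} {box_h / img_h}"


# A 20x40 box whose top-left corner is at (100, 50) in a 400x200 image
print(to_yolo_line(0, 100, 50, 20, 40, 400, 200))
```

Each annotated box produces one such line, and all boxes of an image go into that image's label file.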



The PyTorch implementation of the model is provided on my GitHub repo.



Credit: Redmon, Joseph; Divvala, Santosh; Girshick, Ross; and Farhadi, Ali (2016). You Only Look Once: Unified, Real-Time Object Detection.
Redmon, Joseph and Farhadi, Ali (2016). YOLO9000: Better, Faster, Stronger.
Redmon, Joseph and Farhadi, Ali (2018). YOLOv3: An Incremental Improvement.
https://github.com/marvis/pytorch-yolo2
https://github.com/pjreddie/darknet
https://github.com/eriklindernoren/PyTorch-YOLOv3
https://github.com/ayooshkathuria/pytorch-yolo-v3

