
Deep Learning in Healthcare — X-Ray Imaging (Part 4-The Class Imbalance problem)

source link: https://towardsdatascience.com/deep-learning-in-healthcare-x-ray-imaging-part-4-the-class-imbalance-problem-364eff4d47bb?gi=50df9347c9f8


This is part 4 of the series on applying deep learning to X-ray imaging. Here the focus is on various ways to tackle the class imbalance problem.


Jul 5 ·11min read

As we saw in the previous part, Part 3 (https://towardsdatascience.com/deep-learning-in-healthcare-x-ray-imaging-part-3-analyzing-images-using-python-915a98fbf14c), the chest X-ray dataset is imbalanced: the classes do not contain similar numbers of images. This is the bar chart of the number of images per class that we saw in that part.


Figure 1. The imbalance in the images of the various classes (image by Author)

In medical imaging datasets, this is a very common problem. The data is usually collected from many different sources, and not all diseases are equally prevalent, so these datasets are imbalanced more often than not.

So what is the problem with training a neural network on an imbalanced dataset? The network tends to learn more from the classes with many images than from the classes with few images. In this case, the model might predict ‘Bacterial Pneumonia’ for many images that actually belong to the other two classes, which is an undesirable outcome when dealing with medical images.

It should also be noted that, when dealing with medical images, the overall accuracy of the model (whether train or validation accuracy) is not the right measure on which to judge its performance. Even if the model performs poorly on a particular class, the accuracy can still be high as long as the model performs well on the class with the most images. In reality, we want the model to perform well on all the classes. Thus other metrics, such as sensitivity (recall, or true positive rate, TPR), specificity (true negative rate, TNR), precision (positive predictive value, PPV) and F-scores, should be considered when analyzing the performance of a trained model. We will discuss these in detail in a later part, where we discuss the confusion matrix.
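To make this concrete, here is a small pure-Python sketch on a hypothetical toy set of 10 labels (8 of class 1, 2 of class 0) and a degenerate model that always predicts class 1. The helper functions accuracy, recall and specificity are illustrative names, not a library API. Accuracy looks respectable while the minority class is never detected.

```python
# Toy example (hypothetical labels, not from the chest X-ray dataset):
# 8 images of class 1, 2 images of class 0, and a model that always predicts 1.
y_true = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
y_pred = [1] * 10

def accuracy(y_true, y_pred):
    """Fraction of all predictions that are correct."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def recall(y_true, y_pred, cls):
    """Sensitivity for `cls`: correctly predicted cls / actual examples of cls."""
    actual = [p for t, p in zip(y_true, y_pred) if t == cls]
    return sum(p == cls for p in actual) / len(actual)

def specificity(y_true, y_pred, cls):
    """True negative rate for `cls`: non-cls examples not predicted as cls."""
    negatives = [p for t, p in zip(y_true, y_pred) if t != cls]
    return sum(p != cls for p in negatives) / len(negatives)

print(f"accuracy              = {accuracy(y_true, y_pred):.2f}")        # 0.80
print(f"recall (class 1)      = {recall(y_true, y_pred, 1):.2f}")       # 1.00
print(f"recall (class 0)      = {recall(y_true, y_pred, 0):.2f}")       # 0.00
print(f"specificity (class 1) = {specificity(y_true, y_pred, 1):.2f}")  # 0.00
```

An 80% accuracy hides the fact that not a single class 0 image is recognized.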

It is also essential to keep a separate set of images on which the model is neither trained nor validated, so that we can check how the model performs on images it has never seen before. This is a compulsory step in analyzing the performance of the model.

Various ways to tackle class imbalance:

There are various ways to tackle the class imbalance problem. The best method is to collect more images for the minority classes, but that is not possible in certain situations. In that case, these three methods are commonly used: a. Weighted Loss, b. Undersampling, c. Oversampling.

We will go through each of these methods in detail:

1. Updating the loss function: Weighted Loss

Suppose we are using the Binary Cross-Entropy loss function. The loss function looks like this:

L(X, y) = -log P(Y = 1 | X) if y = 1, and L(X, y) = -log P(Y = 0 | X) if y = 0

This measures the performance of a classification model whose output is a probability between 0 and 1. (This form of the loss only works for binary classification problems. For multiple classes, we use the Categorical Cross-Entropy loss or the Sparse Categorical Cross-Entropy loss. We will discuss basic loss functions in a later part.)

Example: the label of an image is 1, and the neural network predicts that the probability of the label being 1 is 0.2.

Let's apply the loss function to compute the loss for this example. Since the label is 1, we use the first part of the loss function L. The loss L is:

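L = -log(0.2) = 0.70

So this is the loss the algorithm gets for this example. (The worked numbers in this part use base-10 logarithms. Deep learning frameworks use the natural logarithm, which would give -ln(0.2) = 1.61 here; changing the base multiplies every loss by the same constant, so the comparisons below are unaffected.)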

For another image whose label is 0, suppose the network predicts that the probability of the label being 1 is 0.7. We now need the second part of the loss function, but the network only outputs P(Y = 1 | X), so we cannot use that output directly. Instead, since the two probabilities must add up to 1, we compute P(Y = 0 | X) = 1 - P(Y = 1 | X) = 1 - 0.7 = 0.3.

In this case, L = -log(1 - 0.7) = -log(0.3) = 0.52
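The two worked examples can be checked with a short Python sketch. The function name bce is hypothetical; it assumes the network outputs P(Y = 1 | X) and, like the numbers above, uses base-10 logarithms by default.

```python
import math

def bce(y, p_one, log=math.log10):
    """Binary cross-entropy for a single example.

    y     -- true label (0 or 1)
    p_one -- predicted probability that the label is 1, i.e. P(Y = 1 | X)
    log   -- logarithm to use; base 10 by default to match the worked numbers
    """
    if y == 1:
        return -log(p_one)
    return -log(1 - p_one)  # P(Y = 0 | X) = 1 - P(Y = 1 | X)

print(round(bce(1, 0.2), 2))                # 0.7
print(round(bce(0, 0.7), 2))                # 0.52
print(round(bce(1, 0.2, log=math.log), 2))  # 1.61 with the natural log
```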

Now let's look at multiple examples, with class imbalance.


Figure 1. Class Imbalance, Probability and Calculated loss (Source: Image created by author)

In Figure 1, we see there are a total of 10 images, but 8 of them belong to class label 1 and only two belong to class label 0. This is a classic class imbalance problem. Assuming the network predicts a probability of 0.5 for every image:

Loss L for label 1 = -log(0.5) = 0.3,

Loss L for label 0 = -log(1 - 0.5) = -log(0.5) = 0.3

So, the total loss for label 1 = 0.3 x 8 = 2.4

whereas, the total loss for label 0 = 0.3 x 2 = 0.6

So most of the contribution to the loss comes from the examples with label 1. When the algorithm updates the network's weights to reduce the loss, those updates will be driven mostly by the label 1 images rather than by the label 0 images. This does not produce a very good classifier, and this is the Class Imbalance Problem.

One solution to the class imbalance problem is to modify the loss function so that it weighs the 1 and 0 classes differently.

Let w1 be the weight we assign to label 1 examples, and w0 the weight we assign to label 0 examples. The new loss function is:

L(X, y) = -w1 log P(Y = 1 | X) if y = 1, and

L(X, y) = -w0 log P(Y = 0 | X) if y = 0

We want to give more weight to classes with fewer images than to classes with more images. So in this case we give class 1, which has 8 examples, a weight of 2/10 = 0.2, and class 0, which has 2 examples, a weight of 8/10 = 0.8.

Generally, the weights are calculated using the formulas below:

w1 = (number of images with label 0) / (total number of images) = 2/10

w0 = (number of images with label 1) / (total number of images) = 8/10

Below is the updated loss table, using the weighted loss.


Figure 2. Updated Weighted Loss (Source: Image created by author)

For the new calculations, we simply multiply each loss by the weight of its class. Now, if we calculate the total loss:

The total loss for label 1 = 0.06 x 8 = 0.48

The total loss for label 0 = 0.24 x 2 = 0.48

Now both classes have the same total loss. So even though the two classes have different numbers of images, they contribute equally to the loss, and the algorithm no longer favors the majority class when it updates the weights. This makes the classifier much more likely to learn the classes with very few images as well.
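A minimal pure-Python sketch of the weighted loss on the same toy setup (8 images with label 1, 2 with label 0, every prediction equal to 0.5) follows. The function names are hypothetical, and base-10 logarithms are used to match the figures. The unweighted total for label 1 comes out as 2.41 rather than 2.4 because the text rounds each per-image loss to 0.3 first.

```python
import math

# Toy setup from the figures: 8 images with label 1, 2 with label 0,
# and a prediction of P(Y = 1 | X) = 0.5 for every image.
labels = [1] * 8 + [0] * 2
p_one = [0.5] * len(labels)

# Class weights as defined in the text: each class is weighted by the
# fraction of images that belong to the *other* class.
n = len(labels)
w1 = labels.count(0) / n  # 2/10 = 0.2
w0 = labels.count(1) / n  # 8/10 = 0.8

def weighted_bce(y, p, w1, w0):
    """Weighted binary cross-entropy (base-10 log, as in the worked example)."""
    if y == 1:
        return -w1 * math.log10(p)
    return -w0 * math.log10(1 - p)

def total_loss_per_class(labels, p_one, w1, w0):
    """Sum the per-example losses separately for label 0 and label 1."""
    totals = {0: 0.0, 1: 0.0}
    for y, p in zip(labels, p_one):
        totals[y] += weighted_bce(y, p, w1, w0)
    return totals

unweighted = total_loss_per_class(labels, p_one, 1.0, 1.0)
weighted = total_loss_per_class(labels, p_one, w1, w0)
print({k: round(v, 2) for k, v in unweighted.items()})  # {0: 0.6, 1: 2.41}
print({k: round(v, 2) for k, v in weighted.items()})    # {0: 0.48, 1: 0.48}
```

Since w1 × (number of label 1 images) = w0 × (number of label 0 images) = (n0 × n1) / N, each class receives the same total weight; the two weighted totals are exactly equal here because every prediction is the same.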

2. Downsampling

Undersampling (also called downsampling) is the process of removing images from the class with the most images, so that its size becomes comparable to that of the classes with fewer images.

For example, in the pneumonia classification problem, we see that there are 2530 bacterial pneumonia images compared to 1341 normal and 1337 viral pneumonia images. So we can just remove around 1200 images from the bacterial pneumonia class so that all the classes have a similar number of images.

This is only practical for datasets that have many images in every class, so that removing some of them does not hurt the performance of the neural network.
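Random undersampling can be sketched in a few lines of NumPy. This is an illustration, not the method used in this series (which uses oversampling): the function name undersample and its seed parameter are hypothetical, it assumes the examples are stacked along the first axis of an array with the integer labels in a matching array, and the toy labels below simply reproduce the class counts quoted above.

```python
import numpy as np

def undersample(X, y, seed=0):
    """Randomly drop examples so every class keeps as many as the smallest class.

    X -- array of examples (first axis indexes examples)
    y -- integer class labels, shape (N,) or (N, 1)
    """
    rng = np.random.default_rng(seed)
    labels = np.ravel(y)
    classes, counts = np.unique(labels, return_counts=True)
    n_keep = counts.min()
    keep = []
    for c in classes:
        idx = np.flatnonzero(labels == c)  # indices of this class
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    keep = rng.permutation(np.concatenate(keep))  # shuffle the kept indices
    return X[keep], y[keep]

# Dummy features; labels follow the counts quoted above (normal, bacteria, virus).
y_toy = np.array([0] * 1341 + [1] * 2530 + [2] * 1337)
X_toy = np.arange(len(y_toy))
X_small, y_small = undersample(X_toy, y_toy)
print(np.unique(y_small, return_counts=True)[1])  # [1337 1337 1337]
```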

3. Oversampling

Oversampling is the process of adding more images to the minority classes so that their number of images becomes similar to that of the majority class.

The simplest way to do this is to duplicate the images in the minority classes. However, training on exact copies of the same image can cause the network to overfit. To reduce overfitting, we can instead use artificial data augmentation to create new, slightly different images for the minority classes. (This can still cause some overfitting, but it is a much better technique than copying the original images two or three times.)
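For comparison with the augmentation approach used below, plain random oversampling by duplication might look like this NumPy sketch. The function name and seed parameter are hypothetical; every added example is an exact copy of an existing one, which is why this variant is prone to overfitting.

```python
import numpy as np

def oversample_by_duplication(X, y, seed=0):
    """Randomly repeat minority-class examples until every class matches the largest.

    X -- array of examples (first axis indexes examples)
    y -- integer class labels, shape (N,) or (N, 1)
    """
    rng = np.random.default_rng(seed)
    labels = np.ravel(y)
    classes, counts = np.unique(labels, return_counts=True)
    n_target = counts.max()
    extra = []
    for c, n in zip(classes, counts):
        idx = np.flatnonzero(labels == c)
        # draw (n_target - n) indices of this class, with replacement
        extra.append(rng.choice(idx, size=n_target - n, replace=True))
    order = np.concatenate([np.arange(len(labels))] + extra)
    order = rng.permutation(order)  # mix the copies in with the originals
    return X[order], y[order]

y_toy = np.array([0] * 5 + [1] * 12 + [2] * 3)
X_toy = np.arange(len(y_toy))
X_big, y_big = oversample_by_duplication(X_toy, y_toy)
print(np.unique(y_big, return_counts=True)[1])  # [12 12 12]
```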

This is the technique that we used in the pneumonia classification task, and the network worked quite well.

Next, we look at the Python code that generates the artificial images.

import numpy as np
import pandas as pd
import cv2 as cv
import matplotlib.pyplot as plt
import os
import random

from sklearn.model_selection import train_test_split

We have seen all the libraries before, except sklearn.

Scikit-learn (also known as sklearn) is a machine learning library for Python. It contains implementations of many well-known machine learning algorithms for classification and regression, such as support vector machines and random forests. It is also a very important library for data pre-processing in machine learning.

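image_size = 256

labels = ['1_NORMAL', '2_BACTERIA', '3_VIRUS']

def create_training_data(paths):
    # Returns an object array of [image, class_index] pairs
    images = []

    for label in labels:
        class_dir = os.path.join(paths, label)
        class_num = labels.index(label)

        for image in os.listdir(class_dir):
            # cv.imread loads a 3-channel (BGR) image by default, or returns None if the file is not an image
            image_read = cv.imread(os.path.join(class_dir, image))
            if image_read is None:
                continue
            # The third positional argument of cv.resize is the output array, not a read flag,
            # so only the target size is passed here
            image_resized = cv.resize(image_read, (image_size, image_size))
            images.append([image_resized, class_num])

    # dtype=object is needed because each row mixes an image array and an integer
    return np.array(images, dtype=object)

train = create_training_data('D:/Kaggle datasets/chest_xray_tf/train')

X = []
y = []

for feature, label in train:
    X.append(feature)
    y.append(label)

X = np.array(X)
y = np.array(y)
y = np.expand_dims(y, axis=1)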

The above code loads the training dataset, storing the images in X and the labels in y. The details were covered in Part 3 (https://towardsdatascience.com/deep-learning-in-healthcare-x-ray-imaging-part-3-analyzing-images-using-python-915a98fbf14c).

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,random_state = 32, stratify=y)

Since we only have train and validation data, and no test data, we create the test data using train_test_split from sklearn, which splits the data into train and test images and labels. We assign 20% of the data to the test set, hence test_size = 0.2. The random_state argument seeds the random shuffle, so the same split is produced every time we run train_test_split. Setting stratify=y is important here because the data is imbalanced: it makes sure that each class appears in the same proportion in the train and test sets.
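To illustrate what stratification does (a simplified sketch, not scikit-learn's actual implementation), the pure-Python function below shuffles and splits each class separately and then combines the pieces, so every class keeps the same proportion in both sets. The name stratified_split and the toy labels are hypothetical.

```python
import random
from collections import Counter

def stratified_split(labels, test_size=0.2, seed=32):
    """Return (train_idx, test_idx) with the same class proportions in both.

    Simplified illustration: each class is shuffled and split on its own.
    """
    rng = random.Random(seed)
    by_class = {}
    for i, lab in enumerate(labels):
        by_class.setdefault(lab, []).append(i)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)  # test share of this class
        test_idx += idx[:n_test]
        train_idx += idx[n_test:]
    return train_idx, test_idx

toy_labels = [0] * 100 + [1] * 200 + [2] * 100
train_idx, test_idx = stratified_split(toy_labels)
print(sorted(Counter(toy_labels[i] for i in test_idx).items()))
# [(0, 20), (1, 40), (2, 20)]
```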

Important note: oversampling should be done only on the training data, not on the test data. If the test data contained artificially generated images, the test results would not give a proper picture of how much the network has actually learned. So the better method is to first split the data into train and test sets, and then oversample only the training data.

# checking the number of images of each class

a = 0
b = 0
c = 0

for label in y_train:
    if label == 0:
        a += 1
    elif label == 1:
        b += 1
    elif label == 2:
        c += 1

print(f'Number of Normal images = {a}')
print(f'Number of Bacteria images = {b}')
print(f'Number of Virus images = {c}')

# plotting the data

xe = [i for i, _ in enumerate(labels)]

numbers = [a,b,c]
plt.bar(xe,numbers,color = 'green')
plt.xlabel("Labels")
plt.ylabel("No. of images")
plt.title("Images for each label")

plt.xticks(xe, labels)

plt.show()

output -


So now we see that the training set has 1226 normal images, 2184 bacterial pneumonia images, and 1154 viral pneumonia images.

# check the difference from the majority class
difference_normal = b - a
difference_virus = b - c

print(difference_normal)
print(difference_virus)

output —

958

1030
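The counting and subtraction above can be generalized to any number of classes with collections.Counter. This is a small sketch with a hypothetical function name; the toy label list reproduces the training-set counts reported above, and with the article's y_train array one could pass y_train.ravel() instead.

```python
from collections import Counter

def images_needed(labels):
    """For each class, how many extra images are needed to match the majority class."""
    counts = Counter(labels)
    majority = max(counts.values())
    return {cls: majority - n for cls, n in sorted(counts.items())}

# Training-set counts reported above: 0 = normal, 1 = bacteria, 2 = virus.
y_train_flat = [0] * 1226 + [1] * 2184 + [2] * 1154
print(images_needed(y_train_flat))  # {0: 958, 1: 0, 2: 1030}
```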

Solving the imbalance —

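def rotate_images(image, scale=1.0, h=256, w=256):
    # OpenCV expects the center as (x, y) and the output size as (width, height)
    center = (w / 2, h / 2)

    angle = random.randint(-25, 25)  # random angle in degrees
    M = cv.getRotationMatrix2D(center, angle, scale)
    rotated = cv.warpAffine(image, M, (w, h))
    return rotated

def flip(image):
    # horizontal (left-right) flip
    flipped = np.fliplr(image)
    return flipped

def translation(image):
    # random shift of up to 50 pixels along each axis
    x = random.randint(-50, 50)
    y = random.randint(-50, 50)
    rows, cols = image.shape[:2]  # works for both grayscale and 3-channel images
    M = np.float32([[1, 0, x], [0, 1, y]])
    translate = cv.warpAffine(image, M, (cols, rows))
    return translate

def blur(image):
    # random odd kernel size: 1 (leaves the image unchanged) or 3
    x = random.randrange(1, 5, 2)
    # The third positional argument of GaussianBlur is sigmaX, so the border type is
    # passed by keyword; sigmaX=0 lets OpenCV derive sigma from the kernel size
    blur = cv.GaussianBlur(image, (x, x), 0, borderType=cv.BORDER_DEFAULT)
    return blur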

We will be using 4 types of data augmentation, implemented with the OpenCV library: 1. rotation by a random angle between -25 and +25 degrees, 2. horizontal flipping, 3. translation by a random shift of up to 50 pixels along both the x and y axes, 4. Gaussian blurring with a randomly chosen kernel size.

For details on how to implement data augmentation using OpenCV, please see the OpenCV documentation: https://opencv.org

def apply_aug(image):
    # pick one of the four augmentations at random
    number = random.randint(1, 4)

    if number == 1:
        image = rotate_images(image, scale=1.0, h=256, w=256)
    elif number == 2:
        image = flip(image)
    elif number == 3:
        image = translation(image)
    elif number == 4:
        image = blur(image)

    return image

The apply_aug function applies one of the four augmentations, chosen at random, to an image. Next, we define another function that uses it to create the new images for the minority classes.

def oversample_images(difference_normal, difference_virus, X_train, y_train):
    # Makes one augmented copy of each normal (label 0) and viral (label 2) image,
    # in order, until each class has gained difference_normal / difference_virus new images
    normal_counter = 0
    virus_counter = 0
    new_normal = []
    new_virus = []
    label_normal = []
    label_virus = []

    for i, item in enumerate(X_train):

        if y_train[i] == 0 and normal_counter < difference_normal:
            image = apply_aug(item)
            normal_counter = normal_counter + 1
            label = 0
            new_normal.append(image)
            label_normal.append(label)

        if y_train[i] == 2 and virus_counter < difference_virus:
            image = apply_aug(item)
            virus_counter = virus_counter + 1
            label = 2
            new_virus.append(image)
            label_virus.append(label)

    new_normal = np.array(new_normal)
    label_normal = np.array(label_normal)
    new_virus = np.array(new_virus)
    label_virus = np.array(label_virus)

    return new_normal, label_normal, new_virus, label_virus

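This function creates the artificially augmented normal and viral pneumonia images: it goes through the training images in order and makes one augmented copy of each normal or viral pneumonia image, until the number of new images in each class equals that class's difference from the number of bacterial pneumonia images. It then returns the newly created normal and viral pneumonia images and their labels.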
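Because oversample_images makes at most one augmented copy per original image, it would fall short if a class needed more new images than it has (here it does not: 958 < 1226 and 1030 < 1154). A hedged alternative, sketched with NumPy, draws the source images at random with replacement. The name augmented_copies is hypothetical; with the article's data it could be called as augmented_copies(X_train[y_train.ravel() == 0], difference_normal, apply_aug).

```python
import numpy as np

def augmented_copies(images, n_new, augment, seed=0):
    """Create exactly n_new augmented images from `images`.

    Source images are drawn with replacement, so this also works when
    n_new exceeds len(images). `augment` is any function that maps one
    image to a new image (for example the apply_aug function above).
    """
    if n_new == 0:
        return images[:0]  # empty array with the same image shape
    rng = np.random.default_rng(seed)
    picks = rng.integers(0, len(images), size=n_new)  # random source indices
    return np.array([augment(images[i]) for i in picks])

toy = np.zeros((3, 4, 4), dtype=np.uint8)  # 3 tiny dummy "images"
new = augmented_copies(toy, 5, augment=np.fliplr)
print(new.shape)  # (5, 4, 4)
```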

n_images, n_labels, v_images, v_labels = oversample_images(difference_normal, difference_virus, X_train, y_train)

print(n_images.shape)
print(n_labels.shape)
print(v_images.shape)
print(v_labels.shape)

output —


As expected, 958 normal images and 1030 viral pneumonia images have been created.

Let's visualize a few of the artificial normal images,

# Extract 9 random images
print('Display Random Images')

# Adjust the size of your images
plt.figure(figsize=(20, 10))

for i in range(9):
    num = random.randint(0, len(n_images) - 1)
    plt.subplot(3, 3, i + 1)

    plt.imshow(n_images[num], cmap='gray')
    plt.axis('off')

# Adjust subplot parameters to give specified padding
plt.tight_layout()

output -


Next, let’s visualize a few of the artificial viral pneumonia images,

# Displays 9 generated viral images
# Extract 9 random images
print('Display Random Images')

# Adjust the size of your images
plt.figure(figsize=(20, 10))

for i in range(9):
    num = random.randint(0, len(v_images) - 1)
    plt.subplot(3, 3, i + 1)

    plt.imshow(v_images[num], cmap='gray')
    plt.axis('off')

# Adjust subplot parameters to give specified padding
plt.tight_layout()

output -


Each of the images generated above has one augmentation (rotation, translation, flipping or blurring) chosen at random.

Next, we merge these artificial images and their labels with the original training dataset.

new_labels = np.append(n_labels,v_labels)
y_new_labels = np.expand_dims(new_labels, axis=1)
x_new_images = np.append(n_images,v_images,axis=0)

X_train1 = np.append(X_train,x_new_images,axis=0)
y_train1 = np.append(y_train,y_new_labels)

print(X_train1.shape)
print(y_train1.shape)

output —


Now, the training dataset has 6552 images.

bacteria_new = 0
virus_new = 0
normal_new = 0

for i in y_train1:

    if i == 0:
        normal_new = normal_new + 1
    elif i == 1:
        bacteria_new = bacteria_new + 1
    else:
        virus_new = virus_new + 1

print('Number of Normal images =', normal_new)
print('Number of Bacteria images =', bacteria_new)
print('Number of Virus images =', virus_new)

# plotting the data

xe = [i for i, _ in enumerate(labels)]

numbers = [normal_new, bacteria_new, virus_new]
plt.bar(xe,numbers,color = 'green')
plt.xlabel("Labels")
plt.ylabel("No. of images")
plt.title("Images for each label")

plt.xticks(xe, labels)

plt.show()

output —


So finally, the training dataset is balanced: each of the three classes has 2184 images.

So this is how we solved the Class Imbalance Problem. Feel free to try other methods and compare them with the final results.

Now that the class imbalance problem has been dealt with, in the next part we will look into image normalization and data augmentation using Keras and TensorFlow.

