Osteoarthritis X-ray Image Classification


Introduction

Classifying an image is a classic task in computer vision. Here we run this experiment on the Digital Knee X-ray Images dataset with PyTorch, using a CNN, a Vision Transformer, and a Swin Transformer to predict the Kellgren-Lawrence (KL) grade of knee osteoarthritis from the X-ray images.

CNN

A simple CNN

With a CNN, convolutions extract image features layer by layer, from basic shapes up to higher-level structure, e.g. raw input -> vertical + horizontal lines -> corners -> specific objects, much like neurons in the visual cortex that respond to lines of a particular orientation. A kernel like the following captures vertical-line features:

It produces a large value on a window such as this one:
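
As a concrete sketch (the kernel and window values here are illustrative assumptions), a vertical-edge kernel responds strongly to a window that contains a dark-to-bright vertical transition:

import torch
import torch.nn.functional as F

# Sobel-like vertical-edge kernel (illustrative values)
kernel = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]]).view(1, 1, 3, 3)

# A 3x3 window with a vertical dark-to-bright edge
window = torch.tensor([[0., 0., 1.],
                       [0., 0., 1.],
                       [0., 0., 1.]]).view(1, 1, 3, 3)

print(F.conv2d(window, kernel))  # tensor([[[[4.]]]]), a large positive response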

We first write a minimal implementation of a CNN class. The input image is w256*h256*c1, and the network components are defined in torch as follows:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        #...
A convolutional layer can be defined as self.conv1=nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,padding=1). Padding the border of the input with zeros conveniently keeps the output size an integer as convolution and pooling shrink the input.
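
The usual output-size formula makes this concrete; out_size below is a hypothetical helper, not part of the training code:

# Spatial output size of a convolution or pooling layer
def out_size(w, kernel_size, padding=0, stride=1):
    return (w + 2 * padding - kernel_size) // stride + 1

print(out_size(256, kernel_size=3, padding=1))  # 256: padding=1 keeps the size unchanged
print(out_size(256, kernel_size=2, stride=2))   # 128: 2x2 max pooling halves it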

Then we normalize with self.bn1 = nn.BatchNorm2d(num_features=16)

A pooling layer reduces the spatial resolution and hence the number of parameters downstream: self.pool1=nn.MaxPool2d(kernel_size=2)

We stack three such blocks; the channels grow to 64 and the spatial size becomes 256/2/2/2=32, giving high-level features of the image. Finally the features go into a linear fully connected layer, self.classifier=nn.Linear(64*32*32,5)

And we pick ReLU as the activation function, self.relu = nn.ReLU()

Then we define the forward pass, which fixes how data flows through the network:

    def forward(self, x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.pool1(x)

        #...

        x=x.view(x.size(0),-1)
        x=self.classifier(x)

        return x
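
A quick shape check (a sketch, assuming the three conv blocks described above fill in the elided part) confirms the flattened feature size and the 5-class output:

import torch

model = SimpleCNN()
out = model(torch.randn(2, 1, 256, 256))  # a batch of 2 grayscale 256x256 images
print(out.shape)  # torch.Size([2, 5])
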
Full training code
import os
import pandas as pd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device:{device}')

import random
import numpy as np
seed=42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

exps={'MedicalExpert-I','MedicalExpert-II'}
KL_levels={'0Normal','1Doubtful','2Mild','3Moderate','4Severe'}

# Collect (image path, annotating expert, KL grade) triples from the folder structure
img_exp_KLs=[]
for root, dirs, files in os.walk(f'.{os.sep}Digital Knee X-ray Images'):
    for name in files:
        if name.endswith('.png'):
            path_parts=set(root.split(os.sep))
            img_path=os.path.join(root, name)
            img_exp_KL=(img_path,(path_parts&exps).pop(),(path_parts&KL_levels).pop())
            img_exp_KLs.append(img_exp_KL)

dataset_df = pd.DataFrame(img_exp_KLs,columns=['img_path','exp','KL'])

dataset_df.sort_values(by='img_path',inplace=True)

# Keep a single copy of images that both experts graded identically
dataset_df['filename'] = dataset_df['img_path'].str.split(os.sep).str[-1]
dataset_df_unique = dataset_df.drop_duplicates(subset=['KL','filename']).reset_index(drop=True)


data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
])

class KneeXRayDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        image = Image.open(row['img_path']).convert('L')
        label = int(row['KL'][0])

        image = self.transform(image)

        return image, label

X = dataset_df_unique
y = dataset_df_unique['KL']

# 60/20/20 stratified train/val/test split
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed, stratify=y)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, random_state=seed, stratify=y_train_val)

train_dataset = KneeXRayDataset(dataframe=X_train, transform=data_transform)
val_dataset = KneeXRayDataset(dataframe=X_val, transform=data_transform)
test_dataset = KneeXRayDataset(dataframe=X_test, transform=data_transform)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

import hashlib

# Hash the split contents so checkpoints are tied to a specific train/val split
train_dataset_hash=hashlib.md5(''.join(X_train['filename'].to_list()).encode('utf-8')).hexdigest()
val_dataset_hash=hashlib.md5(''.join(X_val['filename'].to_list()).encode('utf-8')).hexdigest()

print(f'train_dataset_hash:{train_dataset_hash},val_dataset_hash:{val_dataset_hash}')

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        self.relu = nn.ReLU()

        self.conv1=nn.Conv2d(1,16,kernel_size=3,padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1=nn.MaxPool2d(kernel_size=2)
        # (256)/2=128

        self.conv2=nn.Conv2d(16,32,kernel_size=3,padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2=nn.MaxPool2d(kernel_size=2)
        # (128)/2=64

        self.conv3=nn.Conv2d(32,64,kernel_size=3,padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3=nn.MaxPool2d(kernel_size=2)
        # (64)/2=32

        self.classifier=nn.Linear(32*32*64,5)

    def forward(self, x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.pool1(x)

        x=self.conv2(x)
        x=self.bn2(x)
        x=self.relu(x)
        x=self.pool2(x)

        x=self.conv3(x)
        x=self.bn3(x)
        x=self.relu(x)
        x=self.pool3(x)

        x=x.view(x.size(0),-1)
        x=self.classifier(x)

        return x


model=SimpleCNN()

# Find the newest checkpoint saved every 10 epochs and resume from it if present
save_epoch=10
latest_model_path=f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth'
while os.path.exists(f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth'):
    latest_model_path=f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth'
    save_epoch+=10

if os.path.exists(latest_model_path):
    model.load_state_dict(torch.load(latest_model_path,map_location=device))
    print(f'Loaded latest model from {latest_model_path}')

model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100

print(f'Training for {num_epochs} epochs')

best_val_loss=float('inf')

for epoch in range(num_epochs):
    model.train() # switch to training mode
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad() # reset gradients
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward() # backpropagate the loss
        optimizer.step() # update the weights
        running_loss += loss.item()

    model.eval() # switch to evaluation mode
    val_loss=0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}, Val Accuracy: {100*correct/total}%')

    if (epoch+1) % 10 == 0:
        torch.save(model.state_dict(), f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{epoch+1}.pth')

    if val_loss/len(val_loader) < best_val_loss:
        best_val_loss = val_loss/len(val_loader)
        torch.save(model.state_dict(), f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.pth')

In the end the model reaches 72.59% accuracy on the 5-class task, which is frankly not much to look at.

Improving the CNN

Next we swap the backbone for a ResNet, change the classification head to a 5-class linear layer, and add the following improvements (sketched in code right after this list):

  • Augment the dataset with flips, rotations, color jitter, and similar transforms.
  • Weight the loss inversely to each class's frequency.
  • L2 regularization.
  • A dynamic learning rate.
  • Early stopping.
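
In PyTorch these items map to just a few lines; the sketch below mirrors the full training code further down (class_weights is the inverse-frequency weight tensor built there):

# Class-weighted loss for the imbalanced KL grades
criterion = nn.CrossEntropyLoss(weight=class_weights)

# weight_decay adds L2 regularization to Adam's updates
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)

# Halve the learning rate when the validation loss plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, min_lr=1e-7)

# Early stopping: give up after 50 epochs without a new best validation loss
early_stopping_patience = 50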

Let's try building ResNet18 ourselves.
(Only after writing it did I notice that the original paper's shortcut uses a kernel_size=1, stride=2 convolution when downsampling. I figured the information loss there could be improved, for example with pooling, and then found that a CVPR paper had already done it ORZ)
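
For reference, a pooled shortcut of the kind hinted at above might look like this (an illustrative sketch, not the exact variant from that paper):

import torch.nn as nn

# Average-pool first, then a stride-1 1x1 projection, so the shortcut
# no longer skips three quarters of the input pixels when downsampling
def pooled_downsample(in_c, out_c):
    return nn.Sequential(
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(out_c),
    )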

Building ResNet
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet18_Weights

class BasicBlock(nn.Module):
    def __init__(self,in_c,out_c,strides=[1,1],padding=1,downsample=None)->None:
        super(BasicBlock, self).__init__()

        self.conv1=nn.Conv2d(in_c,out_c,kernel_size=3,stride=strides[0],padding=padding,bias=False)
        self.bn1=nn.BatchNorm2d(out_c)
        self.relu=nn.ReLU(inplace=True)
        self.conv2=nn.Conv2d(out_c,out_c,kernel_size=3,stride=strides[1],padding=padding,bias=False)
        self.bn2=nn.BatchNorm2d(out_c)
        self.downsample=downsample

    def forward(self,x):
        identity = x

        out=self.conv1(x)
        out=self.bn1(out)
        out=self.relu(out)

        out=self.conv2(out)
        out=self.bn2(out)

        if self.downsample is not None:
            identity=self.downsample(identity)
        out+=identity
        out=self.relu(out)

        return out

class ResNet18(nn.Module):
    def __init__(self,num_classes=1000):
        super(ResNet18,self).__init__()

        self.in_c=64

        self.conv1=nn.Conv2d(3,self.in_c,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(self.in_c)
        self.relu=nn.ReLU(inplace=True)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

        self.layer1=self._make_layer(64,strides=[1,1])
        self.layer2=self._make_layer(128,strides=[2,1])
        self.layer3=self._make_layer(256,strides=[2,1])
        self.layer4=self._make_layer(512,strides=[2,1])

        self.avgpool=nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc=nn.Linear(512,num_classes,bias=True)

    def _make_layer(self,out_c,strides=[1,1]):
        # Project the identity with a strided 1x1 conv when the first block downsamples
        downsample=None
        if strides[0]!=1:
            downsample=nn.Sequential(
                nn.Conv2d(self.in_c,out_c,kernel_size=1,stride=strides[0],bias=False),
                nn.BatchNorm2d(out_c)
            )

        layers=[]
        layers.append(
            BasicBlock(
                self.in_c,out_c,strides=strides,downsample=downsample,  # pass this layer's strides through
            )
        )
        self.in_c=out_c
        layers.append(
            BasicBlock(
                self.in_c,out_c,
            )
        )

        return nn.Sequential(*layers)

    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.maxpool(x)

        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        x=self.layer4(x)

        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.fc(x)

        return x

model=ResNet18(num_classes=1000)

model_ref=models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Sanity check: our parameter names and shapes line up with torchvision's resnet18
model.load_state_dict(model_ref.state_dict(),strict=True)
# <All keys matched successfully>

\(o)/
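
As an extra sanity check, a quick sketch: in eval mode the two models should also produce identical outputs on the same input.

model.eval(); model_ref.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(torch.allclose(model(x), model_ref(x), atol=1e-5))  # expect True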

Full training code
import os
import pandas as pd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from collections import Counter

from tqdm import tqdm

import torch

import random
import numpy as np

import hashlib

class KneeXRayDataset(Dataset):
    def __init__(self, dataframe, transform=None, cache_size=2000):
        self.df = dataframe
        self.transform = transform
        # Cache transformed samples in memory to skip repeated disk reads;
        # note that cached training samples keep the augmentation applied on first load
        self.cache={}
        self.cache_size=cache_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]

        row = self.df.iloc[idx]

        image = Image.open(row['img_path']).convert('RGB')
        label = int(row['KL'][0])

        image = self.transform(image)

        result=(image, label)

        if len(self.cache) < self.cache_size:
            self.cache[idx] = result

        return result

if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device:{device}')

    seed=42
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    exps={'MedicalExpert-I','MedicalExpert-II'}
    KL_levels={'0Normal','1Doubtful','2Mild','3Moderate','4Severe'}

    img_exp_KLs=[]
    for root, dirs, files in os.walk(f'.{os.sep}Digital Knee X-ray Images'):
        for name in files:
            if name.endswith('.png'):
                path_parts=set(root.split(os.sep))
                img_path=os.path.join(root, name)
                img_exp_KL=(img_path,(path_parts&exps).pop(),(path_parts&KL_levels).pop())
                img_exp_KLs.append(img_exp_KL)

    dataset_df = pd.DataFrame(img_exp_KLs,columns=['img_path','exp','KL'])

    dataset_df.sort_values(by='img_path',inplace=True)

    dataset_df['filename'] = dataset_df['img_path'].str.split(os.sep).str[-1]
    dataset_df_unique = dataset_df.drop_duplicates(subset=['KL','filename']).reset_index(drop=True)


    data_transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    data_transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


    X = dataset_df_unique
    y = dataset_df_unique['KL']

    X_train_val, X_test, y_train_val, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y)

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.25, random_state=seed, stratify=y_train_val)

    # Inverse-frequency class weights, ordered by class index 0..4 so they
    # line up with the labels CrossEntropyLoss sees
    class_counts=Counter(int(kl[0]) for kl in y_train)
    all_counts=len(y_train)
    class_weights=torch.tensor([all_counts/class_counts[c] for c in range(5)],
                               dtype=torch.float32).to(device)

    train_dataset = KneeXRayDataset(dataframe=X_train, transform=data_transform_train)
    val_dataset = KneeXRayDataset(dataframe=X_val, transform=data_transform_val)
    test_dataset = KneeXRayDataset(dataframe=X_test, transform=data_transform_val)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                              num_workers=4,pin_memory=True,persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
                            num_workers=4,pin_memory=True,persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


    train_dataset_hash=hashlib.md5(''.join(X_train['filename'].to_list()).encode('utf-8')).hexdigest()
    val_dataset_hash=hashlib.md5(''.join(X_val['filename'].to_list()).encode('utf-8')).hexdigest()

    print(f'train_dataset_hash:{train_dataset_hash},val_dataset_hash:{val_dataset_hash}')


    model=models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 5)


    # Resume from the newest checkpoint saved every 100 epochs, if one exists
    save_epoch=100
    latest_model_path=f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth'
    while(os.path.exists(f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth')):
        latest_model_path=f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{save_epoch}.pth'
        save_epoch+=100

    if os.path.exists(latest_model_path):
        model.load_state_dict(torch.load(latest_model_path,map_location=device))
        save_epoch-=100
        print(f'Loaded latest model from {latest_model_path}')
    else:
        save_epoch=1

    model.to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, min_lr=1e-7)

    num_epochs = 2000

    print(f'Training for {num_epochs} epochs')

    best_val_loss=float('inf')

    early_stopping_patience = 50
    patience_counter = 0

    for epoch in range(save_epoch,num_epochs+save_epoch+1):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch}/{num_epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss=0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch}/{num_epochs}'):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_val_loss = val_loss/len(val_loader)

        scheduler.step(epoch_val_loss)

        print(f"Epoch {epoch}/{num_epochs}, Loss: {running_loss/len(train_loader)}, Val Loss: {epoch_val_loss}, Val Accuracy: {100*correct/total}%, lr: {optimizer.param_groups[0]['lr']}")
        with open(f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.log','a') as f:
            f.write(f"epoch:{epoch},train_loss:{running_loss/len(train_loader)},val_loss:{epoch_val_loss},val_acc:{100*correct/total}%,lr:{optimizer.param_groups[0]['lr']}\n")


        if (epoch) % 100 == 0:
            torch.save(model.state_dict(), f'.{os.sep}models{os.sep}model_{train_dataset_hash}_epoch{epoch}.pth')

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.pth')
            with open(f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.log','a') as f:
                f.write(f'epoch:{epoch},saved best model\n')

            patience_counter=0
        else:
            patience_counter+=1

        if patience_counter >= early_stopping_patience:
            print(f'Early stopping triggered at epoch {epoch}')
            with open(f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.log','a') as f:
                f.write(f'epoch:{epoch},early stopping\n')
            break

In the end our accuracy is 78.01%, a gain of roughly 6 points.

Re-classifying the hard-to-distinguish class

The confusion matrix shows that the model performs especially poorly on 1Doubtful. We suspect that 1Doubtful contains a mix of features that also fit other classes, and that it is hard for humans to tell apart as well.

Next we first remove this class from the training set, then use the trained model to re-classify those samples.
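
A sketch of how the 4-class splits might be built (drop_doubtful and the label column are hypothetical; the dataset class would then read this column instead of decoding the KL string):

# Drop 1Doubtful and remap the remaining KL grades {0, 2, 3, 4} to contiguous labels {0, 1, 2, 3}
remap = {0: 0, 2: 1, 3: 2, 4: 3}

def drop_doubtful(df):
    df = df[df['KL'] != '1Doubtful'].copy()
    df['label'] = df['KL'].str[0].astype(int).map(remap)
    return df

X_train_4, X_val_4, X_test_4 = map(drop_doubtful, (X_train, X_val, X_test))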

The new 4-class model reaches 87.65% accuracy, which is just about presentable.

Re-classifying the 1Doubtful samples and reading off the softmax confidences, we find that many of them fit other classes very well, while some remain genuinely hard to call. The clinical mechanism behind this may need further study, and the KL grading itself may still have room for improvement.
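
A sketch of that re-classification step, assuming model_4cls is the trained 4-class model and doubtful_loader yields the 1Doubtful images (both names are hypothetical):

import torch
import torch.nn.functional as F

model_4cls.eval()
with torch.no_grad():
    for inputs, _ in doubtful_loader:
        inputs = inputs.to(device)
        probs = F.softmax(model_4cls(inputs), dim=1)  # per-class confidence
        conf, pred = probs.max(dim=1)                 # best-matching of the 4 remaining classes
        # High-confidence samples look like clear members of another class;
        # low-confidence ones stay ambiguous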

Swin Transformer

For background on the Swin Transformer, you can learn about it here.

We only need to load the pretrained model and swap out the classification head:

import torch.nn as nn
from torchvision.models.swin_transformer import swin_b, Swin_B_Weights

model = swin_b(weights=Swin_B_Weights.DEFAULT)
model.head=nn.Linear(in_features=1024, out_features=4, bias=True)
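
For the 5-class run reported below, presumably only the head size changes (an assumption based on the accuracy figures that follow):

model.head=nn.Linear(in_features=1024, out_features=5, bias=True)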

In the end the Swin Transformer reaches 80% accuracy on the 5-class task, and 1Doubtful is again by far the worst class.

On the 4-class task it reaches 88% accuracy.

Since the improvement is not large, we won't go into more detail.