Semantic Segmentation of Medical Images


Preface

A traditional CNN extracts features layer by layer; the features from the final convolution/pooling stage can be flattened and fed into fully connected layers to classify the whole image. Jonathan Long et al. dropped the fully connected layers: after obtaining the final low-resolution, high-level feature map, they turned it directly into class scores and then upsampled, so the network predicts a class for every pixel and outputs a classification mask. That raised the curtain on image segmentation.

Here, we run semantic segmentation with FCN and U-net on two datasets, FracAtlas and Leprosy Chronic Wound Images (CO2Wounds-V2), to segment fracture regions and wound regions respectively.

The FracAtlas Dataset

The implementation details of FCN are laid out fairly clearly in that post: the final fully connected layers are replaced, and the backbone output passes through an

FCN Head:

  • a 3x3 convolution that shrinks the channel count to ease the computational load
  • normalization and activation
  • a dropout layer for regularization
  • a 1x1 convolution that maps feature channels to class channels

The head's low-resolution mask output is then upsampled, by interpolation or transposed convolution, to yield per-pixel classifications with the same width and height as the original input.

(The official PyTorch implementation has no skip connections, which conveniently lets us compare it later against U-net, which does.)
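
To make this concrete, here is a minimal sketch of such a head plus bilinear upsampling. The channel sizes and the `FCNHead` name are illustrative assumptions, not torchvision's exact code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FCNHead(nn.Sequential):
    # A minimal FCN-style head: a 3x3 conv shrinks channels, then BN + ReLU,
    # dropout for regularization, and a 1x1 conv maps features to classes.
    def __init__(self, in_channels: int, num_classes: int):
        inter_channels = in_channels // 4
        super().__init__(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, num_classes, 1),
        )

# The low-resolution class logits are upsampled back to the input size,
# here with bilinear interpolation (a transposed convolution also works).
features = torch.randn(1, 2048, 32, 32)          # e.g. a ResNet-50 backbone output
logits = FCNHead(2048, num_classes=2)(features)  # (1, 2, 32, 32)
logits = F.interpolate(logits, size=(256, 256), mode='bilinear', align_corners=False)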

Converting FracAtlas COCO annotations to masks
import os

import cv2
import numpy as np
from pycocotools.coco import COCO
from skimage import io

def coco2mask(img_idx, coco_json, img_path, cat_name=' ', save_dir=None):
    # img_idx selects the img_idx-th image; coco_json and img_path point to the
    # annotation file and the image directory.
    # cat_name is the category name string; if it matches no category,
    # all image ids are returned.
    coco = COCO(coco_json)

    catIds = coco.getCatIds(catNms=[cat_name])  # ids of the requested category

    imgIds = coco.getImgIds(catIds=catIds)  # ids of images containing it
    # loadImgs() returns a list holding a single dict; [0] unwraps it
    img = coco.loadImgs(imgIds[img_idx - 1])[0]
    image = io.imread(img_path + img['file_name'])

    annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
    anns = coco.loadAnns(annIds)

    # rasterize every polygon of every annotation into one binary mask
    mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    for ann in anns:
        segs = ann['segmentation']
        for seg in segs:
            seg = np.array(seg).reshape(-1, 2)
            cv2.fillPoly(mask, seg.astype(np.int32)[np.newaxis, :, :], 255)
    cv2.imwrite(os.path.join(save_dir, img['file_name'].replace('.jpg', '.png')), mask)

coco_json = 'FracAtlas/Annotations/COCO JSON/COCO_fracture_masks.json'
img_path = 'FracAtlas/images/Fractured/'

save_dir = 'FracAtlas/masks'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

img_idx2mask = 1

# walk through the images until the index runs off the end
while True:
    try:
        coco2mask(img_idx2mask, coco_json, img_path, save_dir=save_dir)
        img_idx2mask += 1
    except IndexError:
        break

Remember this well: JPEG is lossy compression, which is exactly why the Ground Truth masks shown further below look so strange.
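
A quick demonstration of the problem (a small sketch; the file names are made up). Saving a binary mask as JPEG introduces intermediate gray values, while PNG round-trips losslessly:

import cv2
import numpy as np

# a binary mask whose values are only in {0, 255}
mask = np.zeros((64, 64), dtype=np.uint8)
cv2.circle(mask, (32, 32), 16, 255, -1)

cv2.imwrite('mask.png', mask)   # lossless
cv2.imwrite('mask.jpg', mask)   # lossy DCT compression

png_back = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)
jpg_back = cv2.imread('mask.jpg', cv2.IMREAD_GRAYSCALE)

print(np.unique(png_back))  # [  0 255] -- unchanged
print(np.unique(jpg_back))  # many in-between values -- edges are smeared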

The dataset class
import os

import numpy as np
from PIL import Image

from torch import float32 as torch_float32
from torch.utils.data import Dataset
from torchvision import tv_tensors
from torchvision.transforms import v2

class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prop=0.5, vflip_prop=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        _transforms = []
        if hflip_prop > 0:
            _transforms.append(v2.RandomHorizontalFlip(hflip_prop))
        if vflip_prop > 0:
            _transforms.append(v2.RandomVerticalFlip(vflip_prop))
        _transforms.append(v2.RandomResizedCrop(crop_size, scale=(0.4, 1.0)))
        _transforms.append(v2.RandomRotation(15))
        _transforms.append(v2.ColorJitter(brightness=0.2, contrast=0.2))
        _transforms.append(v2.ToImage())
        _transforms.append(v2.ToDtype(torch_float32, scale=True))
        _transforms.append(v2.Normalize(mean=mean, std=std))
        self.transforms = v2.Compose(_transforms)

    def __call__(self, img, mask):
        # v2 transforms apply the same random parameters to image and mask
        return self.transforms(img, mask)

class SegmentationPresetVal:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = v2.Compose([
            v2.Resize((base_size, base_size)),
            v2.ToImage(),
            v2.ToDtype(torch_float32, scale=True),
            v2.Normalize(mean=mean, std=std)
        ])

    def __call__(self, img, mask):
        return self.transforms(img, mask)

def get_transform(train: bool = True, base_size=1024, crop_size=480):
    if train:
        return SegmentationPresetTrain(base_size, crop_size)
    else:
        return SegmentationPresetVal(base_size)

class FracAtlasDataset(Dataset):
    def __init__(self, img_dir, mask_dir, img_mask_names, transforms=None) -> None:
        super().__init__()
        self.images = [os.path.join(img_dir, img_name) for img_name in img_mask_names]
        self.masks = [os.path.join(mask_dir, mask_name.replace('.jpg', '.png')) for mask_name in img_mask_names]
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        mask = Image.open(self.masks[idx]).convert('L')

        # binarize the mask: any nonzero pixel is foreground (class 1)
        mask = np.array(mask, dtype=np.uint8)
        mask = (mask > 0).astype(np.uint8)
        mask = Image.fromarray(mask, mode='L')

        # wrap as tv_tensors.Mask so v2 transforms treat it as a mask
        # (nearest-neighbor resampling, no normalization)
        # image = tv_tensors.Image(image)
        mask = tv_tensors.Mask(mask)

        if self.transforms is not None:
            image, mask = self.transforms(image, mask)

        return image, mask

    @staticmethod
    def _cat_list(images, fill_value=0):
        # pad every tensor up to the largest H and W in the batch
        max_size = tuple(max(s_1d) for s_1d in zip(*[image.shape for image in images]))
        batch_shape = (len(images),) + max_size
        batched_imgs = images[0].new_full(batch_shape, fill_value)
        for image, padded_image in zip(images, batched_imgs):
            padded_image[..., :image.shape[-2], :image.shape[-1]].copy_(image)
        return batched_imgs

    @staticmethod
    def collate_fn(batch):
        images, mask = list(zip(*batch))
        batched_images = FracAtlasDataset._cat_list(images)
        batched_mask = FracAtlasDataset._cat_list(mask)
        return batched_images, batched_mask
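
As a quick check of what the padding collate does with mixed-size inputs (a small sketch):

import torch

# _cat_list pads a list of differently sized tensors to the largest
# height/width in the batch, filling the margin with fill_value
imgs = [torch.randn(3, 480, 640), torch.randn(3, 512, 512)]
batched = FracAtlasDataset._cat_list(imgs)
print(batched.shape)  # torch.Size([2, 3, 512, 640])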
Full training code
import os
import re
import random
import hashlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import dataset

random_state = 42
random.seed(random_state); np.random.seed(random_state); torch.manual_seed(random_state)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_state)

batch_size = 8

if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device:{device}')

    img_dir = 'FracAtlas/images/Fractured'
    mask_dir = 'FracAtlas/masks'

    img_mask_names = sorted(os.listdir(img_dir))

    # 60/20/20 train/val/test split
    X_train_val, X_test, y_train_val, y_test = train_test_split(img_mask_names, img_mask_names, test_size=0.2, random_state=random_state)
    X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.25, random_state=random_state)

    train_dataset = dataset.FracAtlasDataset(img_dir, mask_dir, X_train, dataset.get_transform(train=True))
    val_dataset = dataset.FracAtlasDataset(img_dir, mask_dir, X_val, dataset.get_transform(train=False))
    test_dataset = dataset.FracAtlasDataset(img_dir, mask_dir, X_test, dataset.get_transform(train=False))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True, collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True, collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True, collate_fn=test_dataset.collate_fn)

    # hash the split contents so checkpoints and logs are tied to a specific split
    train_dataset_hash = hashlib.md5(''.join(X_train).encode('utf-8')).hexdigest()
    val_dataset_hash = hashlib.md5(''.join(X_val).encode('utf-8')).hexdigest()
    test_dataset_hash = hashlib.md5(''.join(X_test).encode('utf-8')).hexdigest()
    print(f'train_dataset_hash:{train_dataset_hash},val_dataset_hash:{val_dataset_hash},test_dataset_hash:{test_dataset_hash}')

    model = fcn_resnet50(weights=FCN_ResNet50_Weights.DEFAULT, aux_loss=True)

    # replace the pretrained 21-class heads with 2-class (background/fracture) heads
    model.classifier[4] = nn.Conv2d(512, 2, kernel_size=1, stride=1)
    model.aux_classifier[4] = nn.Conv2d(256, 2, kernel_size=1, stride=1)

    save_epoch = 10

    # resume from the latest checkpoint, if any
    saved_models = os.listdir(f'.{os.sep}models{os.sep}')
    saved_models = [m for m in saved_models if m.startswith(f'model_{train_dataset_hash}_on_{val_dataset_hash}_epoch')]
    saved_models_epochs = [int(m.split('_')[-1].split('.')[0].replace('epoch', '')) for m in saved_models]
    saved_models = sorted(zip(saved_models_epochs, saved_models), key=lambda x: x[0], reverse=True)

    if len(saved_models) > 0:
        latest_model_path = f'.{os.sep}models{os.sep}{saved_models[0][1]}'
        model.load_state_dict(torch.load(latest_model_path, map_location=device))
        load_epoch = saved_models[0][0]
        print(f'Loaded latest model from {latest_model_path}')
    else:
        load_epoch = 0
        print('No saved models found')

    # read the log to recover the last lr, the best val loss and the stagnation epochs
    try:
        with open(f'.{os.sep}models{os.sep}log_{train_dataset_hash}_on_{val_dataset_hash}.log', 'r') as f:
            log_lines = f.readlines()
    except FileNotFoundError:
        log_lines = []

    train_record_reg = re.compile(r'epoch:(\d+),train_loss:(\d+\.\d+),val_loss:(\d+\.\d+),val_acc:(\d+\.\d+)%,lr:(.+)')

    # recover the learning rate at load_epoch
    lr_last = 1e-5
    for i in range(len(log_lines) - 1, -1, -1):
        matched_record = re.search(train_record_reg, log_lines[i])
        if matched_record:
            if int(matched_record.group(1)) == load_epoch:
                lr_last = float(matched_record.group(5))
                print(f'Using last lr: {lr_last}')
                break

    # recover the most recent best-model metrics at or before load_epoch
    best_epoch = 0
    best_val_loss_avg_last = float('inf')
    for i in range(len(log_lines) - 1, -1, -1):

        if 'saved best model' in log_lines[i]:

            for j in range(i, -1, -1):
                matched_record = re.search(train_record_reg, log_lines[j])
                if matched_record:
                    if int(matched_record.group(1)) > load_epoch:
                        break
                    else:
                        best_epoch = int(matched_record.group(1))
                        best_val_loss_avg_last = float(matched_record.group(3))
                        print(f'Using last best val loss avg: {best_val_loss_avg_last}')
                        break

        if best_val_loss_avg_last < float('inf'):
            break

    model.to(device)

    def criterion(outputs, targets, train: bool = True):
        # main FCN head loss plus, during training, the auxiliary head loss
        loss_main = nn.functional.cross_entropy(outputs['out'], targets)

        if not train:
            return loss_main

        return loss_main + 0.5 * nn.functional.cross_entropy(outputs['aux'], targets)

    optimizer = optim.Adam(model.parameters(), lr=lr_last, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10, min_lr=1e-7)

    num_epochs_train = 1000

    print(f"Training for {num_epochs_train} epochs")

    best_val_loss_avg = best_val_loss_avg_last

    early_stopping_patience = 50
    patience_counter = max(0, load_epoch - best_epoch)
    if patience_counter > 0:
        print(f'Using last patience counter: {patience_counter}')

    for epoch in range(load_epoch + 1, num_epochs_train + load_epoch + 1):
        model.train()
        running_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch}/{num_epochs_train + load_epoch}'):
            inputs, targets = inputs.to(device), targets.to(device)

            targets = targets.squeeze(1).long()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets, train=True)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        val_loss = 0.0
        tn, fp, fn, tp = 0, 0, 0, 0

        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs_train + load_epoch}"):
                inputs, targets = inputs.to(device), targets.to(device)

                targets = targets.squeeze(1).long()

                outputs = model(inputs)
                val_loss += criterion(outputs, targets, train=False).item()

                predicts = outputs['out'].argmax(dim=1)

                predicts_flat = predicts.flatten().cpu().numpy()
                targets_flat = targets.flatten().cpu().numpy()

                # accumulate a pixel-level confusion matrix over the epoch;
                # labels=[0, 1] keeps the 2x2 shape even if a batch has one class only
                tn_batch, fp_batch, fn_batch, tp_batch = confusion_matrix(targets_flat, predicts_flat, labels=[0, 1]).ravel()
                tn += tn_batch
                fp += fp_batch
                fn += fn_batch
                tp += tp_batch

        val_loss_avg = val_loss / len(val_loader)

        scheduler.step(val_loss_avg)

        print(f"Epoch {epoch}/{num_epochs_train + load_epoch}, Loss: {running_loss/len(train_loader):.3f}, Val Loss: {val_loss_avg:.3f}, Val Accuracy: {100*(tp+tn)/(tp+fp+fn+tn):.3f}%, lr: {optimizer.param_groups[0]['lr']}")
        print(f"Epoch {epoch}/{num_epochs_train + load_epoch}, Sensitivity(Recall): {tp/(tp+fn):.3f}, Specificity: {tn/(tn+fp):.3f}, Precision: {tp/(tp+fp):.3f}")
        with open(f'.{os.sep}models{os.sep}log_{train_dataset_hash}_on_{val_dataset_hash}.log', 'a') as f:
            f.write(f"epoch:{epoch},train_loss:{running_loss/len(train_loader)},val_loss:{val_loss_avg},val_acc:{100*(tp+tn)/(tp+fp+fn+tn)}%,lr:{optimizer.param_groups[0]['lr']}\n")
            f.write(f"epoch:{epoch},sensitivity(Recall):{tp/(tp+fn)},specificity:{tn/(tn+fp)},precision:{tp/(tp+fp)}\n")

        if epoch % save_epoch == 0:
            torch.save(model.state_dict(), f'.{os.sep}models{os.sep}model_{train_dataset_hash}_on_{val_dataset_hash}_epoch{epoch}.pth')

        if val_loss_avg < best_val_loss_avg:
            best_val_loss_avg = val_loss_avg
            torch.save(model.state_dict(), f'.{os.sep}models{os.sep}best_model_{train_dataset_hash}_on_{val_dataset_hash}.pth')
            print(f'epoch:{epoch},saved best model')
            with open(f'.{os.sep}models{os.sep}log_{train_dataset_hash}_on_{val_dataset_hash}.log', 'a') as f:
                f.write(f'epoch:{epoch},saved best model\n')

            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f'Early stopping triggered at epoch {epoch}')
            with open(f'.{os.sep}models{os.sep}log_{train_dataset_hash}_on_{val_dataset_hash}.log', 'a') as f:
                f.write(f'epoch:{epoch},early stopping\n')
            break
The best model was
epoch:110,sensitivity(Recall):0.3831927899240895,specificity:0.9985073638473289,precision:0.6006053270923088
(Figures: an example of a good result, and an example of a bad result.)

The final results are not great, and switching to U-net or to weighted cross-entropy, Dice loss and the like did not help either. I believe this comes down to the dataset's annotation logic: a fracture is really just the break surface itself, and how much surrounding area gets annotated is arbitrary, so this dataset is probably better suited to bbox annotations and object detection.
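
For reference, the Dice loss mentioned above can be written in a few lines. This is a hand-rolled sketch for the two-class case (in practice I used a library implementation; `dice_loss` and its arguments are illustrative):

import torch

def dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    # logits: (N, 2, H, W) class scores; targets: (N, H, W) with values in {0, 1}
    probs = logits.softmax(dim=1)[:, 1]          # foreground probability
    targets = targets.float()
    intersection = (probs * targets).sum(dim=(1, 2))
    union = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
    dice = (2 * intersection + eps) / (union + eps)  # per-image Dice coefficient
    return 1 - dice.mean()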

So in what follows we use the CO2Wounds-V2 dataset instead.

The CO2Wounds-V2 Dataset

FCN

Training FCN with similar code, the best model was

epoch:66,sensitivity(Recall):0.8410347542721659,specificity:0.984027932237679,precision:0.8603991908024482

(Figure: result preview.)

I would even say that in some places the model segments better than the Ground Truth does.

U-net

Without skip connections, an FCN can only brute-force upsample its final low-resolution segmentation, so fine edge details are hard to capture. U-net solves exactly this problem.

The U in U-net vividly depicts its data flow (wouldn't 日-net be an even better name?): feature maps travel along the U from top-left to top-right, depth encodes downsampling strength, and skip connections link the left and right sides at matching resolutions. The left half extracts features, the right half fuses features across scales, so everything from the full-resolution feature maps down to the lowest-resolution ones is captured, which in theory yields better segmentation boundaries.
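
A minimal sketch of one decoder step, to make the skip connection concrete (channel sizes and the `UpBlock` name are illustrative):

import torch
import torch.nn as nn

class UpBlock(nn.Module):
    # One U-net decoder step: upsample the deeper feature map, concatenate
    # the encoder feature map of matching resolution (the skip connection),
    # then fuse them with two 3x3 convolutions.
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch // 2 + skip_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # fuse fine encoder detail with coarse semantics
        return self.conv(x)

deep = torch.randn(1, 256, 32, 32)   # decoder input
skip = torch.randn(1, 128, 64, 64)   # encoder feature at the matching resolution
print(UpBlock(256, 128, 128)(deep, skip).shape)  # torch.Size([1, 128, 64, 64])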

We use U-net through the segmentation-models-pytorch package, which also provides fancier losses such as Dice loss.
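
A sketch of that setup (the encoder choice and arguments here are illustrative, not necessarily the exact configuration I trained with):

import segmentation_models_pytorch as smp

# U-net with a pretrained ResNet-50 encoder; 2 output classes (background, wound)
model = smp.Unet(
    encoder_name='resnet50',
    encoder_weights='imagenet',
    in_channels=3,
    classes=2,
)

# Dice loss for multiclass logits; targets are (N, H, W) index masks
criterion = smp.losses.DiceLoss(mode='multiclass')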

The best model was
epoch:122,sensitivity(Recall):0.820,specificity:0.986,precision:0.869
(Figure: result preview.)

The two are honestly about the same.

Best model

After many experiments, the best model overall was U-net + Dice loss:
epoch:133,sensitivity(Recall):0.904,specificity:0.981,precision:0.849
(Figure: result preview.)

PS

I have noticed that sometimes, framing the task as two-class classification (predicting background and foreground channels) versus foreground-only classification (predicting just a foreground score and thresholding it) gives about the same final performance, yet the latter's loss drops much more slowly during training. I wonder why.
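
For clarity, the two setups side by side (a sketch; the tensor names are hypothetical):

import torch
import torch.nn.functional as F

targets = torch.randint(0, 2, (4, 256, 256))  # (N, H, W) index mask

# (a) two-class setup: background and foreground channels, softmax cross-entropy
logits2 = torch.randn(4, 2, 256, 256)
loss_a = F.cross_entropy(logits2, targets)
pred_a = logits2.argmax(dim=1)

# (b) foreground-only setup: one logit per pixel, BCE, then threshold at 0.5
logit1 = torch.randn(4, 1, 256, 256)
loss_b = F.binary_cross_entropy_with_logits(logit1, targets.unsqueeze(1).float())
pred_b = (logit1.sigmoid() > 0.5).squeeze(1).long()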

Also: .jpg is lossy compression and .png is lossless; I only discovered the problem when visualizing masks that had been saved as .jpg...

And one more thing: if training spills into a lot of shared GPU memory, it becomes painfully slow, even though utilization is displayed as maxed out.