Commit 61866a1

Upload BLPSeg
1 parent 716a652 commit 61866a1

56 files changed, with 5654 additions and 3 deletions

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
MIT License

-Copyright (c) 2023 Hibercraft
+Copyright (c) 2022 Hibercraft

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 93 additions & 2 deletions
@@ -1,4 +1,95 @@
# BLPSeg
-The implementation of [BLPSeg: Balance the Label Preference in Scribble-Supervised Semantic Segmentation](https://github.com/YudeWang/BLPSeg).

-The code is coming soon.
+The implementation of [**BLPSeg: Balance the Label Preference in Scribble-Supervised Semantic Segmentation**](https://ieeexplore.ieee.org/abstract/document/10225696).

## Abstract

Scribble-supervised semantic segmentation is an appealing weakly supervised technique with low labeling cost. Existing approaches mainly consider diffusing the labeled region of scribble by low-level feature similarity to narrow the supervision gap between scribble labels and mask labels. In this study, we observe an annotation bias between scribble and object mask, i.e., label workers tend to scribble on the spacious region instead of corners. This label preference makes the model learn well on those frequently labeled regions but poor on rarely labeled pixels. Therefore, we propose BLPSeg to balance the label preference for complete segmentation. Specifically, the BLPSeg first predicts an annotation probability map to evaluate the rarity of labels on each image, then utilizes a novel BLP loss to balance the model training by up-weighting those rare annotations. Additionally, to further alleviate the impact of label preference, we design a local aggregation module (LAM) to propagate supervision from labeled to unlabeled regions in gradient backpropagation. We conduct extensive experiments to illustrate the effectiveness of our BLPSeg. Our single-stage method even outperforms other advanced multi-stage methods and achieves state-of-the-art performance.
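
The idea of up-weighting rarely annotated pixels can be illustrated with a small sketch. This is not the BLP loss from the paper: the weight `(1 - p) ** gamma`, the function name `rarity_weighted_ce`, and all input values are assumptions chosen only to show how a per-pixel annotation probability `p` could rescale a cross-entropy term.

```python
import math

def rarity_weighted_ce(probs, labels, annot_prob, gamma=2.0, ignore=255):
    """Cross-entropy over labeled pixels, weighted by (1 - p_annot) ** gamma.

    probs:      per-pixel lists of class probabilities (hypothetical model output)
    labels:     per-pixel class index, or `ignore` for unlabeled pixels
    annot_prob: per-pixel probability of being annotated (hypothetical)
    """
    total, weight_sum = 0.0, 0.0
    for p_cls, y, p_ann in zip(probs, labels, annot_prob):
        if y == ignore:
            continue  # unlabeled pixels contribute nothing
        w = (1.0 - p_ann) ** gamma  # rarely annotated pixels get a larger weight
        total += -w * math.log(p_cls[y])
        weight_sum += w
    # Normalize by the total weight so the scale stays comparable to plain CE.
    return total / weight_sum if weight_sum > 0 else 0.0
```

With this weighting, a pixel that is wrong and lies in a rarely annotated region dominates the average, whereas a plain mean would treat every labeled pixel equally.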

## Installation

- Linux with Python 3.6 or newer; note that PyTorch 1.13.0 itself requires Python 3.7+
- PyTorch 1.13.0, torchvision 0.14.0
- CUDA 11.7
- 2 × TITAN RTX GPUs (24 GB each)
- `pip install -r requirements.txt`

## Getting Started

### Preparing Dataset

This repository supports the PASCAL VOC 2012 and PASCAL-Context datasets. Organize them as follows (symbolic links are recommended):
```
data/
    VOCdevkit/
        VOC2012/
            Annotations/
            JPEGImages/
            ImageSets/
            SegmentationClass/
            SegmentationClassAug/
                xxxx.png
                ......
            SegmentationObject/
        Context/
            ImageSets/
            JPEGImages/
            SegmentationClass/
        scribble_annotation/
            pascal_2012/
            pascal_2012_label/
            pascal_context/
            pascal_context_label/
```

1. Download the PASCAL VOC 2012 dataset following the [official instructions](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit).
2. Download the PASCAL VOC 2012 trainaug set (10582 images) from [here](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0) and place the folder at `data/VOCdevkit/VOC2012/SegmentationClassAug/`.
3. Generate the training list file `data/VOCdevkit/VOC2012/ImageSets/Segmentation/trainaug.txt` for the trainaug set (1464 images from the official VOC12 dataset plus 9118 additional images determined by `data/VOCdevkit/VOC2012/SegmentationClassAug`):
   ```
   cd data
   python generate_trainauglist.py
   ```
4. Download the PASCAL-Context dataset from [here](https://www.cs.stanford.edu/~roozbeh/pascal-context/).
5. Download the scribble annotations from [PASCAL-Scribble](https://jifengdai.org/downloads/scribble_sup/), then convert the `.xml` scribble annotation files into `.png` pixel-level annotations:
   ```
   cd data
   python xml2png_voc.py
   python xml2png_context.py
   ```
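
Before training, it can help to confirm that the folders from the layout above are in place. The following is a minimal sketch, not part of the repository: the helper name `missing_dirs` and the choice of which folders to check are assumptions, and the paths assume that `Context/` and `scribble_annotation/` sit directly under `VOCdevkit/`.

```python
import os

# Sub-directories taken from the layout above, relative to the data root.
EXPECTED_DIRS = [
    "VOCdevkit/VOC2012/JPEGImages",
    "VOCdevkit/VOC2012/ImageSets",
    "VOCdevkit/VOC2012/SegmentationClassAug",
    "VOCdevkit/Context/JPEGImages",
    "VOCdevkit/Context/SegmentationClass",
    "VOCdevkit/scribble_annotation/pascal_2012_label",
    "VOCdevkit/scribble_annotation/pascal_context_label",
]

def missing_dirs(data_root, expected=EXPECTED_DIRS):
    """Return the expected sub-directories that do not exist under data_root."""
    return [d for d in expected if not os.path.isdir(os.path.join(data_root, d))]
```

Calling `missing_dirs("data")` from the repository root returns an empty list once every step above has been completed. `os.path.isdir` follows symbolic links, so linked dataset folders are accepted.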

### Train & Evaluation

We take the experiments on PASCAL VOC 2012 as an example. First, switch to the experiment folder:
```
cd experiment/blpseg-voc
```
Adjust the settings in `config.py` as needed, then run:
```
python train.py
```
Check `config_dict['TEST_CKPT']` in `config.py`, then run the evaluation script:
```
python test.py
```

## Model Zoo

| Model | Dataset | mIoU (%, w/o CRF) | Download |
|:------|:--------|------|---------|
| BLPSeg-res101 | PASCAL VOC 2012 | 77.559 | [Google Drive](https://drive.google.com/file/d/13UJZOZVIZDkdbYAhEANJks8in2sbCD93/view?usp=sharing) / [Baiduyun Drive](https://pan.baidu.com/s/1iuKk-8AgMjK78SyEOtj_ow?pwd=d9ie) (code: d9ie) |
| BLPSeg-res101 | PASCAL-Context | 45.745 | [Google Drive](https://drive.google.com/file/d/1TiVU2toU6wr1_xa6nbVuP29up_Wt4tff/view?usp=sharing) / [Baiduyun Drive](https://pan.baidu.com/s/155noxNOA9EnTZ4_6Yy01sA?pwd=pls1) (code: pls1) |

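
For readers unfamiliar with the metric in the table, mean intersection-over-union (mIoU) can be sketched as follows. This is an illustrative pure-Python version, not the repository's evaluation code: the function name `mean_iou`, the flat pixel lists, and the choice to skip classes that never occur are assumptions.

```python
def mean_iou(preds, gts, num_classes, ignore=255):
    """Mean IoU over classes that occur in the predictions or the ground truth.

    preds, gts: flat sequences of per-pixel class indices of equal length.
    Pixels whose ground-truth value equals `ignore` are skipped.
    """
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, g in zip(preds, gts):
        if g == ignore:
            continue
        if p == g:
            inter[g] += 1  # true positive for class g
            union[g] += 1
        else:
            union[p] += 1  # false positive for class p
            union[g] += 1  # false negative for class g
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0
```

Benchmark protocols usually accumulate these counts over the whole validation set before dividing, so per-image averages from this sketch are not directly comparable to the numbers above.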
## Citations

Please cite our paper if the code is helpful to your research.

```
@article{wang2023blpseg,
  title={BLPSeg: Balance the Label Preference in Scribble-Supervised Semantic Segmentation},
  author={Wang, Yude and Zhang, Jie and Kan, Meina and Shan, Shiguang and Chen, Xilin},
  journal={IEEE Transactions on Image Processing},
  year={2023},
  publisher={IEEE}
}
```

data/generate_trainauglist.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
import os
import argparse
import pandas as pd

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--list_folder', type=str, default='./VOCdevkit/VOC2012/ImageSets/Segmentation')
    parser.add_argument('--aug_folder', type=str, default='./VOCdevkit/VOC2012/SegmentationClassAug')
    args = parser.parse_args()

    train_file = os.path.join(args.list_folder, 'train.txt')
    val_file = os.path.join(args.list_folder, 'val.txt')
    trainaug_file = os.path.join(args.list_folder, 'trainaug.txt')
    # One image ID per line; train_list is loaded but not used below.
    train_list = pd.read_csv(train_file, names=['filename'])['filename'].values
    val_list = pd.read_csv(val_file, names=['filename'])['filename'].values
    files = os.listdir(args.aug_folder)
    trainaug_txt_file = open(trainaug_file, 'w')
    for f in files:
        # Strip the 4-character extension ('.png') to get the image ID.
        fname = f[:-4]
        # Keep every augmented image that is not part of the validation split.
        if fname not in val_list:
            trainaug_txt_file.write(f[:-4]+'\n')
    trainaug_txt_file.close()
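
The filtering logic of this script can also be written as a small pure function, which makes it easy to test without touching the file system. This is a sketch under assumptions, not a drop-in replacement: the name `build_trainaug_list` is hypothetical, and unlike the script it ignores non-`.png` files and sorts the result.

```python
import os

def build_trainaug_list(aug_filenames, val_names):
    """Return sorted image IDs from aug_filenames that are not in val_names."""
    val = set(val_names)  # set lookup instead of scanning an array per file
    ids = {os.path.splitext(f)[0] for f in aug_filenames if f.endswith(".png")}
    return sorted(ids - val)
```

For example, `build_trainaug_list(os.listdir(aug_folder), val_list)` would produce the lines to write into `trainaug.txt`, where `aug_folder` and `val_list` stand for the values used in the script above.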

data/xml2png_context.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import os
import argparse
import numpy as np
import xml.dom.minidom as minidom
from xml.dom.minidom import parse
from PIL import Image
from tqdm import tqdm
import time

def xml2dict(xml_file):
    # Parse one scribble XML file into the image size, filename, and scribble polylines.
    result = {}
    tree = minidom.parse(xml_file)
    collection = tree.documentElement
    size = collection.getElementsByTagName('size')[0]
    h = int(size.getElementsByTagName('height')[0].childNodes[0].data)
    w = int(size.getElementsByTagName('width')[0].childNodes[0].data)
    result['size'] = (h,w)
    result['filename'] = collection.getElementsByTagName('filename')[0].childNodes[0].data
    polygons = collection.getElementsByTagName('polygon')
    polygon_list = []
    for polygon in polygons:
        single_polygon_dict = {}
        single_polygon_dict['category'] = polygon.getElementsByTagName('tag')[0].childNodes[0].data
        points = polygon.getElementsByTagName('point')
        point_list = []
        for point in points:
            x = int(point.getElementsByTagName('X')[0].childNodes[0].data)
            y = int(point.getElementsByTagName('Y')[0].childNodes[0].data)
            # Clamp coordinates to the image bounds.
            x = max(min(x,w-1),0)
            y = max(min(y,h-1),0)
            # Points are stored as (row, column).
            point_list.append((y,x))
        single_polygon_dict['points'] = point_list
        polygon_list.append(single_polygon_dict)
    result['polygons'] = polygon_list
    return result

def drawline(img, pos1, pos2, value):
    # Fill the pixels on the segment from pos1 towards pos2 with value.
    # The end point itself is not written here; the caller sets both end points.
    r1,c1 = pos1
    r2,c2 = pos2
    m = max(np.abs(r1-r2), np.abs(c1-c2))
    if m <= 1:
        return img
    delta_r = (r2-r1)/m
    delta_c = (c2-c1)/m
    for i in range(m):
        r = int(r1 + delta_r*i)
        c = int(c1 + delta_c*i)
        img[r,c] = value
    return img

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml', type=str, default='./VOCdevkit/scribble_annotation/pascal_context')
    parser.add_argument('--save', type=str, default='./VOCdevkit/scribble_annotation/pascal_context_label')
    args = parser.parse_args()

    cls2idx = {'background':0, 'plane': 1, 'bike': 2, 'bird': 3, 'boat': 4, 'bottle': 5, 'bus': 6,
               'car': 7, 'cat': 8, 'chair': 9, 'cow': 10, 'table': 11, 'dog': 12, 'horse': 13, 'motorbike': 14,
               'person': 15, 'plant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'monitor': 20, 'bag': 21, 'bed': 22,
               'bench': 23, 'book': 24, 'building': 25, 'cabinet': 26, 'ceiling': 27, 'cloth': 28, 'computer': 29,
               'cup': 30, 'door': 31, 'fence': 32, 'floor': 33, 'flower': 34, 'food': 35, 'grass': 36, 'ground': 37,
               'keyboard': 38, 'light': 39, 'mountain': 40, 'mouse': 41, 'curtain': 42, 'platform': 43, 'sign': 44,
               'plate': 45, 'road': 46, 'rock': 47, 'shelves': 48, 'sidewalk': 49, 'sky': 50, 'snow': 51, 'bedclothes': 52,
               'track': 53, 'tree': 54, 'truck': 55, 'wall': 56, 'water': 57, 'window': 58, 'wood': 59}
    g = os.walk(args.xml)
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    for path, dir_list, file_list in g:
        with tqdm(total=len(file_list)) as pbar:
            pbar.set_description('Processing:')
            for file_name in file_list:
                filename = os.path.join(path, file_name)
                info = xml2dict(filename)
                # 255 marks unlabeled pixels.
                label = np.ones(info['size'])*255
                for polygon in info['polygons']:
                    clsidx = cls2idx[polygon['category']]
                    for i in range(len(polygon['points'])-1):
                        point1 = polygon['points'][i]
                        point2 = polygon['points'][i+1]
                        label = drawline(label, point1, point2, clsidx)
                        label[point1] = clsidx
                        label[point2] = clsidx
                label = label.astype(np.uint8)
                label = Image.fromarray(label)
                out_name = os.path.join(args.save, file_name.replace('.xml','.png'))
                label.save(out_name)
                time.sleep(0.01)
                pbar.update(1)
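
The conversion above can be tried on a tiny in-memory example using only the standard library. This is a sketch under assumptions: the `SAMPLE` document is made up and only mirrors the tag names the script reads (`size`, `height`, `width`, `polygon`, `tag`, `point`, `X`, `Y`), the helper names `text` and `rasterize` are hypothetical, and the line drawing rounds coordinates and writes both end points in one loop rather than reproducing `drawline` exactly.

```python
import xml.dom.minidom as minidom

# Made-up annotation that mirrors the tag names read by the script.
SAMPLE = """<annotation>
  <filename>demo.jpg</filename>
  <size><width>5</width><height>4</height></size>
  <polygon>
    <tag>cat</tag>
    <point><X>0</X><Y>0</Y></point>
    <point><X>3</X><Y>3</Y></point>
  </polygon>
</annotation>"""

def text(node, tag):
    """Text content of the first <tag> element found under node."""
    return node.getElementsByTagName(tag)[0].childNodes[0].data

def rasterize(xml_string, cls2idx, ignore=255):
    """Draw every scribble polyline into an h x w grid of class indices."""
    root = minidom.parseString(xml_string).documentElement
    h, w = int(text(root, "height")), int(text(root, "width"))
    grid = [[ignore] * w for _ in range(h)]
    for poly in root.getElementsByTagName("polygon"):
        value = cls2idx[text(poly, "tag")]
        pts = []
        for p in poly.getElementsByTagName("point"):
            x = min(max(int(text(p, "X")), 0), w - 1)  # clamp to image bounds
            y = min(max(int(text(p, "Y")), 0), h - 1)
            pts.append((y, x))  # (row, column)
        for (r1, c1), (r2, c2) in zip(pts, pts[1:]):
            steps = max(abs(r2 - r1), abs(c2 - c1), 1)
            for i in range(steps + 1):  # include both end points
                r = round(r1 + (r2 - r1) * i / steps)
                c = round(c1 + (c2 - c1) * i / steps)
                grid[r][c] = value
    return grid
```

Calling `rasterize(SAMPLE, {"cat": 8})` marks the diagonal from the top-left corner with class index 8 and leaves every other pixel at 255.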

data/xml2png_voc.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import os
import argparse
import numpy as np
import xml.dom.minidom as minidom
from xml.dom.minidom import parse
from PIL import Image
from tqdm import tqdm
import time

def xml2dict(xml_file):
    # Parse one scribble XML file into the image size, filename, and scribble polylines.
    result = {}
    tree = minidom.parse(xml_file)
    collection = tree.documentElement
    size = collection.getElementsByTagName('size')[0]
    h = int(size.getElementsByTagName('height')[0].childNodes[0].data)
    w = int(size.getElementsByTagName('width')[0].childNodes[0].data)
    result['size'] = (h,w)
    result['filename'] = collection.getElementsByTagName('filename')[0].childNodes[0].data
    polygons = collection.getElementsByTagName('polygon')
    polygon_list = []
    for polygon in polygons:
        single_polygon_dict = {}
        single_polygon_dict['category'] = polygon.getElementsByTagName('tag')[0].childNodes[0].data
        points = polygon.getElementsByTagName('point')
        point_list = []
        for point in points:
            x = int(point.getElementsByTagName('X')[0].childNodes[0].data)
            y = int(point.getElementsByTagName('Y')[0].childNodes[0].data)
            # Clamp coordinates to the image bounds.
            x = max(min(x,w-1),0)
            y = max(min(y,h-1),0)
            # Points are stored as (row, column).
            point_list.append((y,x))
        single_polygon_dict['points'] = point_list
        polygon_list.append(single_polygon_dict)
    result['polygons'] = polygon_list
    return result

def drawline(img, pos1, pos2, value):
    # Fill the pixels on the segment from pos1 towards pos2 with value.
    # The end point itself is not written here; the caller sets both end points.
    r1,c1 = pos1
    r2,c2 = pos2
    m = max(np.abs(r1-r2), np.abs(c1-c2))
    if m <= 1:
        return img
    delta_r = (r2-r1)/m
    delta_c = (c2-c1)/m
    for i in range(m):
        r = int(r1 + delta_r*i)
        c = int(c1 + delta_c*i)
        img[r,c] = value
    return img

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml', type=str, default='./VOCdevkit/scribble_annotation/pascal_2012')
    parser.add_argument('--save', type=str, default='./VOCdevkit/scribble_annotation/pascal_2012_label')
    args = parser.parse_args()

    cls2idx = {'background':0, 'plane':1, 'bike':2, 'bird':3, 'boat':4, 'bottle':5, 'bus':6, 'car':7, 'cat':8, 'chair':9, 'cow':10,
               'table':11, 'dog':12, 'horse':13, 'motorbike':14, 'person':15, 'plant':16, 'sheep':17, 'sofa':18, 'train':19, 'monitor':20}
    g = os.walk(args.xml)
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    for path, dir_list, file_list in g:
        with tqdm(total=len(file_list)) as pbar:
            pbar.set_description('Processing:')
            for file_name in file_list:
                filename = os.path.join(path, file_name)
                info = xml2dict(filename)
                # 255 marks unlabeled pixels.
                label = np.ones(info['size'])*255
                for polygon in info['polygons']:
                    clsidx = cls2idx[polygon['category']]
                    for i in range(len(polygon['points'])-1):
                        point1 = polygon['points'][i]
                        point2 = polygon['points'][i+1]
                        label = drawline(label, point1, point2, clsidx)
                        label[point1] = clsidx
                        label[point2] = clsidx
                label = label.astype(np.uint8)
                label = Image.fromarray(label)
                out_name = os.path.join(args.save, file_name.replace('.xml','.png'))
                label.save(out_name)
                time.sleep(0.01)
                pbar.update(1)
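
Because the script fills every pixel with 255 before drawing the scribbles, the share of supervised pixels in a converted label map is easy to measure. The helper below is a hypothetical sketch (`labeled_fraction` is not part of the repository) and works on a plain 2-D list such as the result of `numpy.ndarray.tolist()`.

```python
def labeled_fraction(label, ignore=255):
    """Fraction of pixels in a 2-D label grid that carry a class index."""
    total = sum(len(row) for row in label)
    labeled = sum(1 for row in label for v in row if v != ignore)
    return labeled / total if total else 0.0
```

Scribble annotations typically cover only a small part of each image, so a value close to zero is expected and does not indicate a failed conversion; a value of exactly zero does.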

experiment/blpseg-context/__init__.py

Whitespace-only changes.

experiment/blpseg-context/config.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------
import torch
import argparse
import os
import sys
import cv2
import time

config_dict = {
    'EXP_NAME': 'blpseg-context',

    'DATA_NAME': 'ContextDataset',
    'DATA_YEAR': 2012,
    'DATA_AUG': True,
    'DATA_WORKERS': 4,
    'DATA_MEAN': [0.485, 0.456, 0.406],
    'DATA_STD': [0.229, 0.224, 0.225],
    'DATA_RANDOMSCALE': [0.75, 1.25],
    'DATA_RANDOM_H': 10,
    'DATA_RANDOM_S': 10,
    'DATA_RANDOM_V': 10,
    'DATA_RANDOMCROP': 384,
    'DATA_RANDOMROTATION': 0,
    'DATA_RANDOMFLIP': 0.5,

    'MODEL_NAME': 'BLPSeg',
    'MODEL_BACKBONE': 'resnet101',
    'MODEL_BACKBONE_PRETRAIN': True,
    'MODEL_PPM_DIM': 256,
    'MODEL_NUM_CLASSES': 60,
    'MODEL_FREEZEBN': False,
    'MODEL_LAM_SIGMA': 6,

    'LOSS_GAMMA': 2.0,
    'LOSS_UNLABEL_CLASS_W': 0.02,

    'TRAIN_LR': 2.4e-5,
    'TRAIN_MOMENTUM': 0.9,
    'TRAIN_WEIGHT_DECAY': 0.01,
    'TRAIN_BN_MOM': 0.1,
    'TRAIN_POWER': 0.9,
    'TRAIN_BATCHES': 8,
    'TRAIN_SHUFFLE': False,
    'TRAIN_MINEPOCH': 0,
    'TRAIN_EPOCHS': 85,
    'TRAIN_TBLOG': True,
    'TRAIN_ST_POINT': 20000,

    'TEST_MULTISCALE': [0.5, 0.75, 1, 1.25],
    'TEST_FLIP': True,
    'TEST_CRF': False,
    'TEST_BATCHES': 1,
}

# "__file__" is a string literal here, so its dirname is empty and ROOT_DIR
# resolves to two levels above the current working directory. Run the scripts
# from inside this experiment folder.
config_dict['ROOT_DIR'] = os.path.abspath(os.path.join(os.path.dirname("__file__"),'..','..'))
config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'],'model',config_dict['EXP_NAME'])
config_dict['TRAIN_CKPT'] = None
config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'],'log',config_dict['EXP_NAME'])
config_dict['TEST_CKPT'] = os.path.join(config_dict['ROOT_DIR'],f'model/{config_dict["EXP_NAME"]}/BLPSeg_resnet101_ContextDataset_epoch85.pth')

# Make the shared code under <ROOT_DIR>/lib importable.
sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib'))
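
The `TRAIN_LR` and `TRAIN_POWER` keys suggest a polynomial ("poly") learning-rate decay, which is common in semantic segmentation. The training script is not shown in this part of the commit, so the following is only a sketch of that common schedule under this assumption; the function name `poly_lr` and its iteration-based arguments are hypothetical.

```python
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """Polynomial decay: base_lr * (1 - cur_iter / max_iter) ** power."""
    return base_lr * (1.0 - cur_iter / max_iter) ** power
```

With `base_lr=2.4e-5` and `power=0.9`, the rate starts at `2.4e-5`, falls almost linearly, and reaches zero at `max_iter`.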
