Skip to content

Commit 94aabdb

Browse files
committed
Added object detection
1 parent a022a6e commit 94aabdb

19 files changed

+1214
-1
lines changed

.DS_Store

0 Bytes
Binary file not shown.

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ Machine_Learning_A-Z.Rproj
66
.DS_Store
77
Deep-Learning/Convolutional-Neural-Networks(CNN)/dataset
88
.gitignore
9-
*.pyc
9+
*.pyc
10+
*.pth
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES
2+
from .config import *
3+
import cv2
4+
import numpy as np
5+
6+
7+
def base_transform(image, size, mean):
    """Resize ``image`` to a ``size`` x ``size`` square and subtract ``mean``.

    Arguments:
        image: image array accepted by ``cv2.resize``
        size (int): width and height of the resized output
        mean: value(s) subtracted in place from the resized float32 image

    Returns:
        the resized, mean-subtracted image as a float32 array
    """
    resized = cv2.resize(image, (size, size))
    shifted = resized.astype(np.float32)
    shifted -= mean
    return shifted.astype(np.float32)
13+
14+
15+
class BaseTransform:
    """Callable that applies ``base_transform`` with a fixed size and mean.

    Arguments:
        size (int): target side length handed to ``base_transform``
        mean (sequence of numbers): converted once to a float32 numpy array
            and handed to ``base_transform`` as the value to subtract
    """

    def __init__(self, size, mean):
        self.mean = np.array(mean, dtype=np.float32)
        self.size = size

    def __call__(self, image, boxes=None, labels=None):
        # only the image is transformed; boxes and labels pass through as-is
        transformed = base_transform(image, self.size, self.mean)
        return transformed, boxes, labels
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# config.py
2+
import os.path
3+
4+
# resolve the current user's home directory in a cross-platform way
home = os.path.expanduser("~")
ddir = os.path.join(home, "data/VOCdevkit/")

# path to VOCdevkit root dir
# note: if you used our download scripts, this should be right
VOCroot = ddir

BATCHES = 32    # default batch size
SHUFFLE = True  # data reshuffled at every epoch
WORKERS = 4     # number of subprocesses to use for data loading


# SSD300 CONFIGS
# newer version: use additional conv11_2 layer as last layer before multibox layers
v2 = dict(
    feature_maps=[38, 19, 10, 5, 3, 1],
    min_dim=300,
    steps=[8, 16, 32, 64, 100, 300],
    min_sizes=[30, 60, 111, 162, 213, 264],
    max_sizes=[60, 111, 162, 213, 264, 315],
    # alternative form with explicit reciprocals, kept for reference:
    # [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3],
    #  [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]]
    aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    variance=[0.1, 0.2],
    clip=True,
    name='v2',
)

# use average pooling layer as last layer before multibox layers
v1 = dict(
    feature_maps=[38, 19, 10, 5, 3, 1],
    min_dim=300,
    steps=[8, 16, 32, 64, 100, 300],
    min_sizes=[30, 60, 114, 168, 222, 276],
    max_sizes=[-1, 114, 168, 222, 276, 330],
    # alternative compact form, kept for reference:
    # [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]]
    aspect_ratios=[
        [1, 1, 2, 1/2],
        [1, 1, 2, 1/2, 3, 1/3],
        [1, 1, 2, 1/2, 3, 1/3],
        [1, 1, 2, 1/2, 3, 1/3],
        [1, 1, 2, 1/2, 3, 1/3],
        [1, 1, 2, 1/2, 3, 1/3],
    ],
    variance=[0.1, 0.2],
    clip=True,
    name='v1',
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/bin/bash
# Ellis Brown
#
# Downloads and extracts the PASCAL VOC2007 trainval and test archives.
# Usage: pass an existing directory as the first argument to download there;
# with no argument the data goes to ~/data (created if missing).
# Exits with a non-zero status if the target directory is invalid or if a
# download or extraction step fails.

start=$(date +%s)

# handle optional download dir
if [ -z "$1" ]
then
    # navigate to ~/data
    echo "navigating to ~/data/ ..."
    mkdir -p ~/data
    # abort rather than download into whatever directory we started in
    cd ~/data/ || exit 1
else
    # check if is valid directory (quoted so paths containing spaces work)
    if [ ! -d "$1" ]; then
        echo "$1" "is not a valid directory"
        # non-zero status so callers can detect the failure
        exit 1
    fi
    echo "navigating to" "$1" "..."
    cd "$1" || exit 1
fi

echo "Downloading VOC2007 trainval ..."
# Download the data. -f makes curl fail on HTTP errors instead of saving the
# server's error page under the archive's name.
curl -fLO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar || exit 1
echo "Downloading VOC2007 test data ..."
curl -fLO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar || exit 1
echo "Done downloading."

# Extract data; stop before deleting the archives if extraction fails
echo "Extracting trainval ..."
tar -xvf VOCtrainval_06-Nov-2007.tar || exit 1
echo "Extracting test ..."
tar -xvf VOCtest_06-Nov-2007.tar || exit 1
echo "removing tars ..."
rm VOCtrainval_06-Nov-2007.tar
rm VOCtest_06-Nov-2007.tar

end=$(date +%s)
runtime=$((end-start))

echo "Completed in" $runtime "seconds"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/bin/bash
# Ellis Brown
#
# Downloads and extracts the PASCAL VOC2012 trainval archive.
# Usage: pass an existing directory as the first argument to download there;
# with no argument the data goes to ~/data (created if missing).
# Exits with a non-zero status if the target directory is invalid or if the
# download or extraction step fails.

start=$(date +%s)

# handle optional download dir
if [ -z "$1" ]
then
    # navigate to ~/data
    echo "navigating to ~/data/ ..."
    mkdir -p ~/data
    # abort rather than download into whatever directory we started in
    cd ~/data/ || exit 1
else
    # check if is valid directory (quoted so paths containing spaces work)
    if [ ! -d "$1" ]; then
        echo "$1" "is not a valid directory"
        # non-zero status so callers can detect the failure
        exit 1
    fi
    echo "navigating to" "$1" "..."
    cd "$1" || exit 1
fi

echo "Downloading VOC2012 trainval ..."
# Download the data. -f makes curl fail on HTTP errors instead of saving the
# server's error page under the archive's name.
curl -fLO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar || exit 1
echo "Done downloading."


# Extract data; stop before deleting the archive if extraction fails
echo "Extracting trainval ..."
tar -xvf VOCtrainval_11-May-2012.tar || exit 1
echo "removing tar ..."
rm VOCtrainval_11-May-2012.tar

end=$(date +%s)
runtime=$((end-start))

echo "Completed in" $runtime "seconds"
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""VOC Dataset Classes
2+
3+
Original author: Francisco Massa
4+
https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
5+
6+
Updated by: Ellis Brown, Max deGroot
7+
"""
8+
9+
import os
10+
import os.path
11+
import sys
12+
import torch
13+
import torch.utils.data as data
14+
import torchvision.transforms as transforms
15+
from PIL import Image, ImageDraw, ImageFont
16+
import cv2
17+
import numpy as np
18+
if sys.version_info[0] == 2:
19+
import xml.etree.cElementTree as ET
20+
else:
21+
import xml.etree.ElementTree as ET
22+
23+
# the 20 VOC class names; a name's position in this tuple is the label index
# AnnotationTransform assigns when no explicit class_to_ind is given
VOC_CLASSES = (
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

# for making bounding boxes pretty
# (4-component tuples, presumably RGBA with a half-transparent alpha -- confirm)
COLORS = (
    (255, 0, 0, 128),
    (0, 255, 0, 128),
    (0, 0, 255, 128),
    (0, 255, 255, 128),
    (255, 0, 255, 128),
    (255, 255, 0, 128),
)
33+
34+
35+
class AnnotationTransform(object):
    """Transforms a VOC annotation into a list of scaled bbox coords and label index
    Initialized with a dictionary lookup of classnames to indexes

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: each class's position in VOC_CLASSES)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
            width (number): divisor applied to the x coordinates (xmin, xmax)
            height (number): divisor applied to the y coordinates (ymin, ymax)
        Returns:
            a list with one [xmin, ymin, xmax, ymax, label_ind] list per
            kept object
        Raises:
            KeyError: if an object's name is not a key of class_to_ind
        """
        res = []
        for obj in target.iter('object'):
            # an object without a <difficult> tag is treated as not difficult
            difficult_node = obj.find('difficult')
            difficult = (difficult_node is not None
                         and int(difficult_node.text) == 1)
            if difficult and not self.keep_difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            bndbox = []
            for i, pt in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
                # the -1 presumably shifts 1-based VOC pixel coords to
                # 0-based -- confirm against the dataset spec
                cur_pt = int(bbox.find(pt).text) - 1
                # even positions are x values, odd positions are y values;
                # float() forces true division on Python 2 as well
                scale = width if i % 2 == 0 else height
                bndbox.append(float(cur_pt) / scale)
            bndbox.append(self.class_to_ind[name])
            res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
82+
83+
84+
class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_sets (iterable of (year, name) string pairs): imagesets to use,
            eg. [('2007', 'trainval')]; image ids are read from
            <root>/VOC<year>/ImageSets/Main/<name>.txt
        transform (callable, optional): transformation to perform on the
            input image; called as transform(img, boxes, labels) and must
            return an (img, boxes, labels) triple
        target_transform (callable, optional): transformation to perform on the
            target `annotation`; called as
            target_transform(annotation_root, width, height)
        dataset_name (string, optional): name stored on the dataset object
            (default: 'VOC0712')
    """

    def __init__(self, root, image_sets, transform=None, target_transform=None,
                 dataset_name='VOC0712'):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        # both templates are filled with a (rootpath, image id) pair
        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = os.path.join(self.root, 'VOC' + year)
            list_path = os.path.join(rootpath, 'ImageSets', 'Main',
                                     name + '.txt')
            # context manager so the file handle is closed right away
            with open(list_path) as id_file:
                for line in id_file:
                    img_id = line.strip()
                    # blank lines would otherwise become empty image ids
                    if img_id:
                        self.ids.append((rootpath, img_id))

    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)

        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        '''Loads the image and annotation at index and applies the transforms

        Argument:
            index (int): index of the item to load
        Return:
            tuple: (image tensor permuted to channels-first, target,
                original image height, original image width)
        '''
        img_id = self.ids[index]

        target = ET.parse(self._annopath % img_id).getroot()
        img = cv2.imread(self._imgpath % img_id)
        height, width = img.shape[:2]

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            if target.size == 0:
                # no objects: keep a 2-D (0, 5) layout so the column
                # slicing below still works
                target = target.reshape(0, 5)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # to rgb
            img = img[:, :, (2, 1, 0)]
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

    def pull_image(self, index):
        '''Returns the original image at index as loaded by cv2.imread

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            the result of cv2.imread(..., cv2.IMREAD_COLOR) for that image
        '''
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)

    def pull_anno(self, index):
        '''Returns the annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality. target_transform must be set; it
        is called with width=1 and height=1.

        Argument:
            index (int): index of img to get annotation of
        Return:
            tuple: (img_id, result of target_transform(annotation, 1, 1))
        '''
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno, 1, 1)
        return img_id[1], gt

    def pull_tensor(self, index):
        '''Returns the original image at an index in tensor form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            tensorized version of img with a leading dimension of size 1 added
        '''
        return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
187+
188+
189+
def detection_collate(batch):
    """Collate samples whose images carry differing numbers of annotations.

    Arguments:
        batch: (iterable) samples where sample[0] is an image tensor and
            sample[1] is that image's annotations

    Return:
        A tuple containing:
            1) (tensor) the images stacked along a new dimension 0
            2) (list of tensors) one FloatTensor of annotations per image,
               in batch order
    """
    images = [sample[0] for sample in batch]
    annotations = [torch.FloatTensor(sample[1]) for sample in batch]
    return torch.stack(images, 0), annotations
6.17 MB
Binary file not shown.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .functions import *
2+
from .modules import *

0 commit comments

Comments
 (0)