def resize_image_bb(read_path, write_path, bb, sz, ratio=1.49):
    """Resize an image and its bounding box, and write the image to a new path.

    Args:
        read_path: Path of the source image (``pathlib.Path`` or str).
        write_path: Directory the resized image is written into; the output
            file keeps the source file name.
        bb: Bounding box of the source image, in the form ``create_mask``
            accepts.
        sz: Target height in pixels.
        ratio: Target width/height ratio; the output width is
            ``int(ratio * sz)``. Defaults to the previously hard-coded 1.49.

    Returns:
        ``(new_path, new_bb)``: the written file path as a string, and the
        bounding box recovered by ``mask_to_bb`` from the resized mask.

    Raises:
        OSError: if OpenCV fails to write the resized image (for example,
            because ``write_path`` does not exist).
    """
    from pathlib import Path

    im = read_image(read_path)
    target = (int(ratio * sz), sz)  # cv2.resize takes (width, height)
    im_resized = cv2.resize(im, target)
    # Rasterize the box into a mask and resize that mask with the same target
    # size, so the box is scaled exactly like the image.
    Y_resized = cv2.resize(create_mask(bb, im), target)
    new_path = str(Path(write_path) / Path(read_path).name)
    # The image is assumed to be RGB; OpenCV writes BGR, hence the conversion.
    ok = cv2.imwrite(new_path, cv2.cvtColor(im_resized, cv2.COLOR_RGB2BGR))
    if not ok:
        # cv2.imwrite reports failure via its return value rather than raising.
        raise OSError(f"cv2.imwrite failed to write {new_path}")
    return new_path, mask_to_bb(Y_resized)
# Populating Training DF with new paths and bounding boxes
new_paths = []
new_bbs = []
train_path_resized = Path('./road_signs/images_resized')
# cv2.imwrite does not create missing directories (it just returns False),
# so make sure the output directory exists before resizing.
train_path_resized.mkdir(parents=True, exist_ok=True)
for _, row in df_train.iterrows():
    new_path, new_bb = resize_image_bb(
        row['filename'], train_path_resized, create_bb_array(row.values), 300
    )
    new_paths.append(new_path)
    new_bbs.append(new_bb)
df_train['new_path'] = new_paths
df_train['new_bb'] = new_bbs
4. 数据增强
数据增强是一种利用现有图像的不同变体生成新训练图像、从而提升模型泛化能力的技术。我们当前的训练集只有 800 张图像,因此数据增强对于防止模型过拟合非常重要。
针对这个问题,我使用了翻转、旋转、中心裁剪和随机裁剪。
唯一需要注意的是:包围盒必须与图像以完全相同的方式进行变换。
# modified from fast.ai
def crop(im, r, c, target_r, target_c):
    """Return the ``target_r`` x ``target_c`` window of ``im`` whose top-left corner is (r, c)."""
    row_span = slice(r, r + target_r)
    col_span = slice(c, c + target_c)
    return im[row_span, col_span]


# Random crop: the window is smaller than the input and placed at a random offset.
def random_crop(x, r_pix=8):
    """Return a crop of ``x`` taken at a random offset.

    The window is ``2 * r_pix`` rows and ``2 * col_pix`` columns smaller than
    ``x``, where ``col_pix`` is ``r_pix`` scaled by width/height. The offsets
    are drawn from the ``random`` module.
    """
    height, width, *_ = x.shape
    col_pix = round(r_pix * width / height)
    frac_r = random.uniform(0, 1)
    frac_c = random.uniform(0, 1)
    top = int(np.floor(2 * frac_r * r_pix))
    left = int(np.floor(2 * frac_c * col_pix))
    return crop(x, top, left, height - 2 * r_pix, width - 2 * col_pix)


def center_crop(x, r_pix=8):
    """Return the centred crop that trims ``r_pix`` rows and a proportional number of columns from each side."""
    height, width, *_ = x.shape
    col_pix = round(r_pix * width / height)
    return crop(x, r_pix, col_pix, height - 2 * r_pix, width - 2 * col_pix)
def rotate_cv(im, deg, y=False, mode=cv2.BORDER_REFLECT, interpolation=cv2.INTER_AREA):
    """Rotate an image (or mask) by ``deg`` degrees about its centre, keeping its size.

    Args:
        im: Array whose first two dimensions are rows and columns.
        deg: Rotation angle in degrees, passed to ``cv2.getRotationMatrix2D``.
        y: If True, ``im`` is treated as a mask. It is warped with a constant
            border and OpenCV's default interpolation, and ``mode`` and
            ``interpolation`` are ignored.
        mode: Border mode used for the image case.
        interpolation: Interpolation flag used for the image case.

    Returns:
        The rotated array, with the same width and height as ``im``.
    """
    r, c, *_ = im.shape
    M = cv2.getRotationMatrix2D((c / 2, r / 2), deg, 1)
    if y:
        # A constant border keeps areas rotated in from outside the mask empty,
        # instead of reflecting copies of the box into them.
        return cv2.warpAffine(im, M, (c, r), borderMode=cv2.BORDER_CONSTANT)
    return cv2.warpAffine(im, M, (c, r), borderMode=mode,
                          flags=cv2.WARP_FILL_OUTLIERS + interpolation)


def random_cropXY(x, Y, r_pix=8):
    """Crop an image and its mask with the same random window.

    The window is ``2 * r_pix`` rows and ``2 * c_pix`` columns smaller than
    ``x`` (``c_pix`` is ``r_pix`` scaled by width/height). Its offset is drawn
    once and applied to both arrays so the mask stays aligned with the image.

    Returns:
        ``(xx, YY)``: the cropped image and the cropped mask.
    """
    r, c, *_ = x.shape
    c_pix = round(r_pix * c / r)
    rand_r = random.uniform(0, 1)
    rand_c = random.uniform(0, 1)
    start_r = np.floor(2 * rand_r * r_pix).astype(int)
    start_c = np.floor(2 * rand_c * c_pix).astype(int)
    xx = crop(x, start_r, start_c, r - 2 * r_pix, c - 2 * c_pix)
    YY = crop(Y, start_r, start_c, r - 2 * r_pix, c - 2 * c_pix)
    return xx, YY


def transformsXY(path, bb, transforms):
    """Load an image, build its box mask, optionally augment both, and return them.

    Args:
        path: Image file path (anything ``str()`` turns into a path).
        bb: Bounding box, in the form ``create_mask`` accepts.
        transforms: If True, apply three augmentations to both the image and
            the mask:
            - a random rotation in [-10, 10) degrees;
            - a horizontal flip with probability ~0.5;
            - a random crop.
            Otherwise apply a deterministic centre crop.

    Returns:
        ``(x, bb)``: the float32 RGB image divided by 255, and the bounding
        box recovered from the transformed mask by ``mask_to_bb``.

    Raises:
        OSError: if OpenCV cannot read the image at ``path``.
    """
    im = cv2.imread(str(path))
    if im is None:
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising; fail here with a clear message.
        raise OSError(f"cv2.imread could not read image: {path}")
    x = im.astype(np.float32)
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) / 255
    Y = create_mask(bb, x)
    if transforms:
        rdeg = (np.random.random() - .50) * 20  # uniform angle in [-10, 10)
        x = rotate_cv(x, rdeg)
        Y = rotate_cv(Y, rdeg, y=True)
        if np.random.random() > 0.5:
            # np.fliplr returns a view; copy() materializes the flipped array.
            x = np.fliplr(x).copy()
            Y = np.fliplr(Y).copy()
        x, Y = random_cropXY(x, Y)
    else:
        x, Y = center_crop(x), center_crop(Y)
    return x, mask_to_bb(Y)
def create_corner_rect(bb, color='red'):
    """Build an unfilled matplotlib rectangle for a corner-format box.

    ``bb[0]``/``bb[1]`` are the y/x of one corner and ``bb[2]``/``bb[3]`` the
    y/x of the opposite corner; matplotlib wants (x, y), width and height.
    """
    box = np.array(bb, dtype=np.float32)
    top, left, bottom, right = box[0], box[1], box[2], box[3]
    return plt.Rectangle((left, top), right - left, bottom - top,
                         color=color, fill=False, lw=3)


def show_corner_bb(im, bb):
    """Display ``im`` and overlay the corner-format box ``bb`` on the current axes."""
    plt.imshow(im)
    axes = plt.gca()
    axes.add_patch(create_corner_rect(bb))
文章插图
图片
5. PyTorch 数据集
有了数据增强之后,我们就可以划分训练集和验证集,并创建 PyTorch 数据集。由于我们使用的是预训练的 ResNet 模型,因此使用 ImageNet 的统计量(各通道的均值和标准差)对图像进行标准化;训练时则在数据集中应用数据增强。
# Hold out 20% of the samples for validation; random_state fixes the split
# so it is reproducible between runs.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=42,
)
def normalize(im, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize an image channel-wise: ``(im - mean) / std``.

    Args:
        im: Channel-last image array; its last axis must match the length of
            ``mean`` and ``std`` (three for RGB).
        mean: Per-channel means. Defaults to the ImageNet statistics that were
            previously hard-coded here.
        std: Per-channel standard deviations. Defaults to the ImageNet
            statistics.

    Returns:
        The normalized array. The statistics are float64, so the result is
        promoted to float64, exactly as before.
    """
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    return (im - mean) / std
class RoadDataset(Dataset):
    """PyTorch dataset yielding ``(image, class label, bounding box)`` triples.

    Args:
        paths: Image file paths, one per sample (pandas Series, list or array).
        bb: Bounding boxes aligned with ``paths``.
        y: Class labels aligned with ``paths``.
        transforms: Passed to ``transformsXY`` to switch random augmentation
            on (True) or off (False).
    """

    def __init__(self, paths, bb, y, transforms=False):
        self.transforms = transforms
        # np.asarray yields the same array data as Series.values for pandas
        # inputs, and also accepts plain lists and arrays.
        self.paths = np.asarray(paths)
        self.bb = np.asarray(bb)
        self.y = np.asarray(y)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, transform and normalize sample ``idx``.

        Returns:
            ``(x, y_class, y_bb)``: the channel-first normalized image, its
            class label, and the bounding box of the transformed image.
        """
        path = self.paths[idx]
        y_class = self.y[idx]
        x, y_bb = transformsXY(path, self.bb[idx], self.transforms)
        x = normalize(x)
        # HWC -> CHW: move the channel axis to the front for PyTorch.
        # np.moveaxis is the recommended replacement for np.rollaxis.
        x = np.moveaxis(x, 2, 0)
        return x, y_class, y_bb
# Random augmentation (transforms=True) is enabled only for the training set;
# the validation set keeps RoadDataset's default transforms=False.
train_ds = RoadDataset(
    X_train['new_path'],
    X_train['new_bb'],
    y_train,
    transforms=True,
)
valid_ds = RoadDataset(
    X_val['new_path'],
    X_val['new_bb'],
    y_val,
)
batch_size = 64
# The training loader reshuffles every epoch; the validation loader keeps
# dataset order (DataLoader's default shuffle=False).
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(dataset=valid_ds, batch_size=batch_size)
6. PyTorch 模型
对于这个模型,我使用了一个非常简单的预训练 ResNet-34 模型。由于需要同时完成两个任务,模型有两个最终输出层:包围盒回归器和图像分类器。
推荐阅读
- 16个优秀的开源微信小程序项目,接单赚钱利器!
- 让Java起飞的技术...
- 即将到来的 Vue 3 “Vapor Mode”
- 学会使用Java的远程调试工具,解决难题
- Oracle数据库调优实战:优化SQL查询的黄金法则!
- JVM的调优常用参数
- API请求重试的8种方法,你用哪种?
- 利用Docker简化机器学习应用程序的部署和可扩展性
- 2024年的后端和Web开发趋势
- 警惕“应用推荐”背后的信贷陷阱