数据增强

tech2025-10-30  1

  过拟合的原因是学习样本太少,导致无法训练出能够泛华到新数据的模型。如果拥有无限的数据,那么模型能够观察到数据分布的所有内容,这样就永远不会过拟合。数据增强是从现有的训练样本中生成更多的训练数据,其方法是利用多种能够生成可信图像的随机变换来增加(augment)样本。其目标是,模型在训练时不会两次查看相同的图像。这让模型能够观察到数据的更多内容,从而具有更好的泛华能力。

  在Keras中,这可以通过对ImageDataGenerator实例读取的图像执行多次随机变换来实现。我们先来看一下例子。

datagen = ImageDataGenerator( rotation_range=40, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode='nearest')

这里只选择了几个参数。我们来快速了解一下这些参数的含义。

rotation_range是角度值(在0-180范围内),表示图像随机旋转的角度范围。width_shift和height_shift是图像在水平或者垂直方向上平移的范围(相当于总宽度或总高度的比例)。shear_range是随机错切变换的角度。zoom_range是图像随机缩放的范围。horizontal_flip是随机将一半图像水平翻转。如果没有水平不对称的假设(比如真实世界的图像),这种做法是有意义的。fill_mode是用于填充新创建像素的方法,这些新像素可能来自于旋转或宽度/高度平移。 from keras.preprocessing import image fnames = [os.path.join(train_cats_dir,fname) for fname in os.listdir(train_cats_dir)] img_path = fnames[3] img = image.load_img(img_path,target_size=(150,150)) x = image.img_to_array(img) x = x.reshape((1,) + x.shape) i= 0 for batch in datagen.flow(x,batch_size=1): plt.figure(i) imgplot = plt.imshow(image.array_to_img(batch[0])) i += 1 if i % 4 == 0: break plt.show()

如果你使用这种数据增强来训练一个新网络,那么网络将不会两次看到同样的输入。但网络看到的输入仍然是高度相关的,因为这些输入都来自于少量的原始图像。你无法生成新信息,而只能混合现有信息。因此,这种方法可能不足以完全消除过拟合。为了进一步降低过拟合,你还需要向模型中添加一个Dropout层,添加到密集连接分类器之前。

 

model = models.Sequential() model.add(layers.Conv2D(32,(3,3),activation='relu',input_shape=(150,150,3))) model.add(layers.MaxPooling2D((2,2))) model.add(layers.Conv2D(64,(3,3),activation='relu',input_shape=(150,150,3))) model.add(layers.MaxPooling2D((2,2))) model.add(layers.Conv2D(128,(3,3),activation='relu',input_shape=(150,150,3))) model.add(layers.MaxPooling2D((2,2))) model.add(layers.Conv2D(128,(3,3),activation='relu',input_shape=(150,150,3))) model.add(layers.MaxPooling2D((2,2))) model.add(layers.Flatten()) model.add(layers.Dropout(0.5)) model.add(layers.Dense(512,activation='relu')) model.add(layers.Dense(1,activation='sigmoid')) model.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4), metrics=['acc'] )

我们来训练这个使用了数据增强和dropout的网络。

train_datagen = ImageDataGenerator( rescale=1./255, rotation_range=40, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True) test_datagen = ImageDataGenerator(rescale=1./255) train_generator = train_datagen.flow_from_directory( train_dir, target_size=(150,150), batch_size=32, class_mode='binary') validation_generator = validation_datagen.flow_from_directory( validation_dir, target_size=(150,150), batch_size=32, class_mode='binary') history = model.fit_generator( train_generator, steps_per_epoch=100, epochs=100, validation_data=validation_generator, validation_steps=50)

 

最新回复(0)