目录
数据增强
    tf.keras.preprocessing.image.ImageDataGenerator()
断点续训
提取可训练参数
可视化准确率上升和损失下降
训练代码
给图识物
数据增强
tf.keras.preprocessing.image.ImageDataGenerator()
ImageDataGenerator() 要求输入数据是四维的（样本数, 高, 宽, 通道数）。如果输入数据不是四维，例如 MNIST 的 x_train 形状为 (60000, 28, 28)，需要先 reshape 成 (60000, 28, 28, 1)。
断点续训
把训练好的模型参数保存下来。再次运行代码时，先加载上次保存的参数，就能在上次训练结果的基础上继续训练，不必从头开始。
提取可训练参数
查看模型的可训练参数（model.trainable_variables，即各层的权重和偏置），并把它们保存到文本文件中。
可视化准确率上升和损失下降
画图代码
# Plot per-epoch accuracy (left panel) and loss (right panel) for the
# training and validation sets.
# Assumes `history` is the History object returned by model.fit(...), trained
# with metrics=['sparse_categorical_accuracy'] and validation data, and that
# `plt` is matplotlib.pyplot.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
训练代码
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# Print NumPy arrays in full (no "..." truncation). The weight values written
# to weights.txt further down go through str(v.numpy()), so without this the
# dumped weights would be abbreviated.
np.set_printoptions(threshold=np.inf)

# Load MNIST and scale pixel values from [0, 255] into [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Fully connected classifier: flatten the input image, one 128-unit ReLU
# hidden layer, then a 10-way softmax output (one unit per digit class).
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The output layer already applies softmax, so the loss receives
# probabilities rather than logits (from_logits=False). The labels are used
# as plain integer class ids, hence the "sparse" loss and metric.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)
# Resume training: if a previous run left a checkpoint, restore its weights
# before fitting. A TF-format weight checkpoint at this prefix produces a
# "<prefix>.index" file, which is what the existence test looks for.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when the monitored quantity (val_loss by
# default) improves, so the checkpoint holds the best epoch of this run.
# NOTE(review): under Keras 3, save_weights_only=True requires a filepath
# ending in ".weights.h5"; this ".ckpt" prefix form targets TF 2.x tf.keras.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()
print(model.trainable_variables)

# Write each trainable variable's name, shape and values to weights.txt.
# The `with` statement closes the file even if a write raises, which the
# previous open()/close() pair did not guarantee.
with open('./weights.txt', 'w', encoding='utf-8') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')
# Plot per-epoch accuracy (left panel) and loss (right panel) for the
# training and validation sets, using the History returned by model.fit
# above. The 'val_*' keys exist because fit() was given validation_data.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
给图识物