Syntax of fit (R interface, keras package):

    fit(object, x = NULL, y = NULL, batch_size = NULL, epochs = 10,
        verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL,
        view_metrics = getOption("keras.view_metrics", default = "auto"),
        validation_split = 0, validation_data = NULL, shuffle = TRUE,
        class_weight = NULL, sample_weight = NULL, initial_epoch = 0,
        steps_per_epoch = NULL, validation_steps = NULL, ...)

Characteristics: (1) the entire training set (x and y) is held in RAM at once;
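To make characteristic (1) concrete, here is a minimal plain-Python sketch of the batching `fit` performs over in-memory data. It is not Keras source code: the helper names `split_validation` and `epoch_batches` are hypothetical, and plain lists stand in for arrays. It follows two documented behaviours: the last `validation_split` fraction of the samples is held out before any shuffling, and with `shuffle = TRUE` the remaining samples are reshuffled each epoch and walked in batches of `batch_size` (Keras uses 32 when `batch_size` is NULL).

```python
import random

# Hypothetical helpers sketching the batching fit() does over in-memory data;
# this is not Keras source code.


def split_validation(x, y, validation_split=0.0):
    """Hold out the last validation_split fraction of samples, before shuffling."""
    split_at = int(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


def epoch_batches(x, y, batch_size=32, shuffle=True, rng=random):
    """Yield (x_batch, y_batch) pairs that cover every training sample once."""
    indices = list(range(len(x)))
    if shuffle:
        rng.shuffle(indices)  # new order every epoch, like shuffle = TRUE
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]
        yield [x[i] for i in idx], [y[i] for i in idx]


# Usage: 100 toy samples, all held in memory at once.
x = list(range(100))
y = [i % 2 for i in x]
(x_train, y_train), (x_val, y_val) = split_validation(x, y, validation_split=0.2)
rng = random.Random(0)
for epoch in range(2):
    sizes = [len(xb) for xb, _ in epoch_batches(x_train, y_train, 32, True, rng)]
    print(epoch, sizes)  # prints "0 [32, 32, 16]", then "1 [32, 32, 16]"
```

Because `x` and `y` are complete in memory, any sample can be indexed at any time, which is what makes per-epoch shuffling cheap and also why the whole dataset must fit in RAM.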
Syntax of fit_generator:

    fit_generator(object, generator, steps_per_epoch, epochs = 1,
        verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL,
        view_metrics = getOption("keras.view_metrics", default = "auto"),
        validation_data = NULL, validation_steps = NULL, class_weight = NULL,
        max_queue_size = 10, workers = 1, initial_epoch = 0)

Characteristics: (1) training proceeds batch by batch, with each batch pulled from the generator. The key is to define a suitable generator, for example one that reads data from a file:
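    def generate_arrays_from_file(path):
        # process_line() and load_images() are user-defined helpers, not Keras functions
        while True:  # loop forever: fit_generator keeps pulling batches across epochs
            with open(path) as f:
                for line in f:
                    # create numpy arrays of input data
                    # and labels, from each line in the file
                    x, y = process_line(line)
                    img = load_images(x)
                    yield (img, y)

    # Keras 2 argument names: steps_per_epoch counts yielded batches
    # (it replaces samples_per_epoch), and epochs replaces nb_epoch
    model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                        steps_per_epoch=10000, epochs=10)

Characteristics: (1) flexible: only one batch of data at a time is loaded and forward-propagated, so the dataset does not have to fit in RAM. Note that from TensorFlow 2.1 onward fit_generator is deprecated, because fit accepts generators directly.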
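The file-reading pattern above can be made runnable without Keras. This sketch assumes a hypothetical CSV layout (`f1,f2,...,label` per line, parsed by a hypothetical `process_line`), groups lines into batches of `batch_size` instead of yielding one sample per step, and replaces the real training loop with a hypothetical `run_epochs` that only counts the samples it pulls, `steps_per_epoch` batches per epoch. It illustrates the generator contract rather than reproducing Keras's own code; with a real model you would convert `xs` and `ys` with `numpy.asarray` before yielding.

```python
import math
import os
import tempfile


def process_line(line):
    """Hypothetical parser: assumes each line looks like 'f1,f2,...,label'."""
    *features, label = line.strip().split(",")
    return [float(v) for v in features], int(label)


def generate_batches_from_file(path, batch_size=32):
    """Cycle over the file forever, yielding (x_batch, y_batch) tuples."""
    while True:  # the training loop may ask for more batches than one pass holds
        with open(path) as f:
            xs, ys = [], []
            for line in f:
                if not line.strip():
                    continue  # skip blank lines
                x, y = process_line(line)
                xs.append(x)
                ys.append(y)
                if len(xs) == batch_size:
                    yield xs, ys  # only this batch is held in memory
                    xs, ys = [], []
            if xs:
                yield xs, ys  # final, smaller batch at end of file


def run_epochs(generator, steps_per_epoch, epochs=1, initial_epoch=0):
    """Hypothetical stand-in for fit_generator's loop: pulls steps_per_epoch
    batches per epoch and returns the number of samples seen in each epoch."""
    seen_per_epoch = []
    for _epoch in range(initial_epoch, epochs):
        seen = 0
        for _step in range(steps_per_epoch):
            x_batch, _y_batch = next(generator)
            seen += len(x_batch)
        seen_per_epoch.append(seen)
    return seen_per_epoch


# Usage: a 5-line file read in batches of 2.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
    for i in range(5):
        tmp.write(f"{i},{i * 2},{i % 2}\n")
batch_size = 2
steps = math.ceil(5 / batch_size)  # 3 batches = one full pass over the file
gen = generate_batches_from_file(tmp.name, batch_size)
print(run_epochs(gen, steps_per_epoch=steps, epochs=2))  # [5, 5]
gen.close()  # closes the file the suspended generator still holds open
os.remove(tmp.name)
```

Setting `steps_per_epoch` to ceil(number of lines / `batch_size`) makes one epoch correspond to one pass over the file. A smaller value splits a pass across epochs, because the generator resumes where it left off instead of restarting.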