RError.com

RError.com Logo RError.com Logo

RError.com Navigation

  • 主页

Mobile menu

Close
  • 主页
  • 系统&网络
    • 热门问题
    • 最新问题
    • 标签
  • Ubuntu
    • 热门问题
    • 最新问题
    • 标签
  • 帮助
主页 / 问题 / 918915
Accepted
Семен Романов
Семен Романов
Asked:2020-12-12 16:53:49 +0000 UTC2020-12-12 16:53:49 +0000 UTC 2020-12-12 16:53:49 +0000 UTC

如何使用在 Keras 生成器上训练的神经网络?

  • 772

在发电机上训练 keras 模型,现在如果我这样做

model.predict_generator(test_generator)

那么一切正常,但是我现在如何将其应用于图像?我需要正确地将图像转换为数组(3种颜色)并调用,下面的示例不起作用(opencv中的图像)

data = img.astype(float)/255
model.predict(data)
#ValueError('Error when checking input: expected vgg16_input to have 4 dimensions, but got array with shape (32, 32, 3)',)

或者将打开的图像转换为生成器,但是怎么做呢?

发电机整形:

def Get_generator_data(dir,img_width, img_height, batch_size):
    datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = datagen.flow_from_directory(
    dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)
    return train_generator

模型创建

def create_model(outNeron, size):


    # Загружаем предварительно обученную нейронную сеть VGG16
    vgg16_net = VGG16(weights='imagenet', include_top=False, 
                      input_shape=(size, size, 3))

    # "Замораживаем" веса предварительно обученной нейронной сети VGG16
    vgg16_net.trainable = False

    # Создаем составную нейронную сеть на основе VGG16
    # Создаем последовательную модель Keras
    model = Sequential()
    # Добавляем в модель сеть VGG16 вместо слоя
    model.add(vgg16_net)
    # Добавляем в модель новый классификатор
    model.add(Flatten())
    model.add(Dense(256))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(outNeron))
    model.add(Activation('softmax'))
    return model

模型训练

def Train_Model (model, train_generator, nb_train_samples, batch_size, val_generator, nb_validation_samples):
        сheckpoint = ModelCheckpoint('save/mnist-dense.hdf5', 
            monitor='val_acc', 
            save_best_only=True)
    # Компилируем составную нейронную сеть
        model.compile(loss=losses.categorical_crossentropy,
            optimizer=Adam(lr=1e-5), 
            metrics=['accuracy'])
    # Обучаем модель с использованием генераторов
        model.fit_generator(
            train_generator,
            steps_per_epoch=nb_train_samples // batch_size,
            epochs=35,
            validation_data=val_generator,
            validation_steps=nb_validation_samples // batch_size,
            callbacks=[сheckpoint])

模型打印 print(model.summary())

Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Model)                (None, 1, 1, 512)         14714688
_________________________________________________________________
flatten (Flatten)            (None, 512)               0
_________________________________________________________________
dense (Dense)                (None, 256)               131328
_________________________________________________________________
activation (Activation)      (None, 256)               0
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 1028
_________________________________________________________________
activation_1 (Activation)    (None, 4)                 0
=================================================================
Total params: 14,847,044
Trainable params: 132,356
Non-trainable params: 14,714,688
_________________________________________________________________
None
python-3.x
  • 1 1 个回答
  • 10 Views

1 个回答

  • Voted
  1. Best Answer
    MaxU - stop genocide of UA
    2020-12-13T17:37:59Z2020-12-13T17:37:59Z

    尝试将所有图像读入一个 4D 数组并“输入”这个数组model.predict()。

    from pathlib import Path
    from skimage.transform import resize as sk_resize
    from skimage.io import imread as sk_imread
    
    
    def read_image(fn, img_width=32, img_height=32, channels=3,
                   mode='reflect', anti_aliasing=True):
        return sk_resize(sk_imread(fn),
                         output_shape=(img_width, img_height, channels),
                         mode=mode, anti_aliasing=anti_aliasing)
    
    def read_images(files, img_width=32, img_height=32, channels=3):
        return np.array([read_image(f, img_width, img_height, channels)
                         for f in files])
    
    size=32
    path = Path(r'C:\download')
    imgs4D = read_images([str(f) for f in path.glob('*.jpg')],
                         size, size, 3)
    
    predictions = model.predict(imgs4D)
    
    • 2

相关问题

  • 如何在 tkinter 库中为 python 编程语言中的按钮制作不同的字体?

  • 通过 selenium webdriver 在打开的浏览器中删除通知窗口

  • 编写一个函数,找出所有 5 位数字等于输入值的数字之和

  • PyQt5 缺少参数错误

  • 将参数列表传递给类构造函数

  • 循环跟踪的迭代

Sidebar

Stats

  • 问题 10021
  • Answers 30001
  • 最佳答案 8000
  • 用户 6900
  • 常问
  • 回答
  • Marko Smith

    是否可以在 C++ 中继承类 <---> 结构?

    • 2 个回答
  • Marko Smith

    这种神经网络架构适合文本分类吗?

    • 1 个回答
  • Marko Smith

    为什么分配的工作方式不同?

    • 3 个回答
  • Marko Smith

    控制台中的光标坐标

    • 1 个回答
  • Marko Smith

    如何在 C++ 中删除类的实例?

    • 4 个回答
  • Marko Smith

    点是否属于线段的问题

    • 2 个回答
  • Marko Smith

    json结构错误

    • 1 个回答
  • Marko Smith

    ServiceWorker 中的“获取”事件

    • 1 个回答
  • Marko Smith

    c ++控制台应用程序exe文件[重复]

    • 1 个回答
  • Marko Smith

    按多列从sql表中选择

    • 1 个回答
  • Martin Hope
    Alexandr_TT 圣诞树动画 2020-12-23 00:38:08 +0000 UTC
  • Martin Hope
    Suvitruf - Andrei Apanasik 什么是空? 2020-08-21 01:48:09 +0000 UTC
  • Martin Hope
    Air 究竟是什么标识了网站访问者? 2020-11-03 15:49:20 +0000 UTC
  • Martin Hope
    Qwertiy 号码显示 9223372036854775807 2020-07-11 18:16:49 +0000 UTC
  • Martin Hope
    user216109 如何为黑客设下陷阱,或充分击退攻击? 2020-05-10 02:22:52 +0000 UTC
  • Martin Hope
    Qwertiy 并变成3个无穷大 2020-11-06 07:15:57 +0000 UTC
  • Martin Hope
    koks_rs 什么是样板代码? 2020-10-27 15:43:19 +0000 UTC
  • Martin Hope
    Sirop4ik 向 git 提交发布的正确方法是什么? 2020-10-05 00:02:00 +0000 UTC
  • Martin Hope
    faoxis 为什么在这么多示例中函数都称为 foo? 2020-08-15 04:42:49 +0000 UTC
  • Martin Hope
    Pavel Mayorov 如何从事件或回调函数中返回值?或者至少等他们完成。 2020-08-11 16:49:28 +0000 UTC

热门标签

javascript python java php c# c++ html android jquery mysql

Explore

  • 主页
  • 问题
    • 热门问题
    • 最新问题
  • 标签
  • 帮助

Footer

RError.com

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

帮助

© 2023 RError.com All Rights Reserve   沪ICP备12040472号-5