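I'm trying to rework a network that plays tic-tac-toe by adding four numbers to its input. In other words, the input becomes the board plus four additional numbers. When the Concatenate layer is called, execution crashes with: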
A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 5, 5), (None, 4)]
Please tell me how to fix this error.
The network:
    class TicTacToeNNet():
        def __init__(self, game, args):
            # game params
            self.board_x, self.board_y = game.getBoardSize()
            self.action_size = game.getActionSize()
            self.args = args

            # Neural Net
            input1 = Input(shape=(self.board_x, self.board_y))  # s: batch_size x board_x x board_y
            input2 = Input(shape=(4,))
            merged = Concatenate(axis=1)([input1, input2])
            h_conv1 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(args.num_channels, 3, padding='same')(merged)))  # batch_size x board_x x board_y x num_channels
            h_conv2 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(args.num_channels, 3, padding='same')(h_conv1)))  # batch_size x board_x x board_y x num_channels
            h_conv3 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(args.num_channels, 3, padding='same')(h_conv2)))  # batch_size x (board_x) x (board_y) x num_channels
            h_conv4 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(args.num_channels, 3, padding='valid')(h_conv3)))  # batch_size x (board_x-2) x (board_y-2) x num_channels
            h_conv4_flat = Flatten()(h_conv4)
            s_fc1 = Dropout(args.dropout)(Activation('relu')(BatchNormalization(axis=1)(Dense(1024)(h_conv4_flat))))  # batch_size x 1024
            s_fc2 = Dropout(args.dropout)(Activation('relu')(BatchNormalization(axis=1)(Dense(512)(s_fc1))))  # batch_size x 512
            self.pi = Dense(self.action_size, activation='softmax', name='pi')(s_fc2)  # batch_size x self.action_size
            self.v = Dense(1, activation='tanh', name='v')(s_fc2)  # batch_size x 1
            self.model = Model(inputs=self.input_boards, outputs=[self.pi, self.v])
            self.model.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer=Adam(args.lr))
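Well, you have a 5x5 matrix and a vector of size 4. How exactly do you want to join them?

There are a few options, of course. You could apply Flatten to the matrix right away and append the vector. But then all the Conv2D layers would stop working, because you end up with a vector of length 29, whereas Conv2D layers work on matrices and look for two-dimensional patterns in the data. Converting the matrix to a vector loses some of that information.

I would suggest merging somewhere after the Conv2D layers, at the point where you have already converted to a vector: after the Flatten layer, or even further along, somewhere after a Dense or Dropout layer. Which placement works best is up to you. If those 4 numbers are very different in nature from the matrix, I would concatenate closer to the output layer.

In any case, take any layer whose output is no longer a matrix but a vector, i.e. anything from h_conv4_flat onward, and concatenate there.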
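
As a rough sketch of the first variant (concatenating right after Flatten), the network could look something like the code below. Several things here are my assumptions and are not in the question: the helper name build_model, the default values for num_channels, dropout, lr and action_size, and the Reshape layer, which I added because Conv2D needs a channel axis and the posted code doesn't show one. Note that the model now takes both inputs.

```python
from tensorflow.keras.layers import (
    Activation, BatchNormalization, Concatenate, Conv2D,
    Dense, Dropout, Flatten, Input, Reshape,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def build_model(board_x=5, board_y=5, n_extra=4, action_size=26,
                num_channels=64, dropout=0.3, lr=0.001):
    """Hypothetical helper; all default values are illustrative assumptions."""
    board_in = Input(shape=(board_x, board_y), name='board')  # batch x 5 x 5
    extra_in = Input(shape=(n_extra,), name='extra')          # batch x 4

    # Assumption: add a channel axis so Conv2D receives a 4D tensor.
    x = Reshape((board_x, board_y, 1))(board_in)

    # Only the board goes through the convolutions.
    for padding in ('same', 'same', 'same', 'valid'):
        x = Conv2D(num_channels, 3, padding=padding)(x)
        x = BatchNormalization(axis=3)(x)
        x = Activation('relu')(x)

    # batch x ((board_x-2) * (board_y-2) * num_channels)
    flat = Flatten()(x)

    # Both tensors are now 2D (batch x features), so they can be
    # concatenated along axis 1.
    x = Concatenate(axis=1, name='merged')([flat, extra_in])

    for units in (1024, 512):
        x = Dense(units)(x)
        x = BatchNormalization(axis=1)(x)
        x = Activation('relu')(x)
        x = Dropout(dropout)(x)

    pi = Dense(action_size, activation='softmax', name='pi')(x)
    v = Dense(1, activation='tanh', name='v')(x)

    # The model must be given both inputs.
    model = Model(inputs=[board_in, extra_in], outputs=[pi, v])
    model.compile(loss=['categorical_crossentropy', 'mean_squared_error'],
                  optimizer=Adam(learning_rate=lr))
    return model
```

To merge closer to the output instead, move the Concatenate call further down, for example after the first Dense/Dropout block. When training or predicting, pass the data as a list of two arrays, the boards and the extra numbers, in the same order as the inputs.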