
Concatenate two models with tensorflow.keras

I’m currently studying neural network models for image analysis with the MNIST dataset. I first built a model that uses only the image. Then I created an additional variable that is 0 when the digit is between 0 and 4, and 1 when it is greater than or equal to 5.
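The additional variable can be derived directly from the digit labels. Here is a minimal NumPy sketch. `make_info_feature` is a hypothetical helper name. With real data you would call it on the `y_train` / `y_test` label arrays returned by `tf.keras.datasets.mnist.load_data()`. The result is reshaped to `(n, 1)` to match the `input_shape=(1,)` of the info model.

```python
import numpy as np

def make_info_feature(labels):
    """Hypothetical helper: 0.0 for digits 0-4, 1.0 for digits 5-9, shaped (n, 1)."""
    labels = np.asarray(labels)
    # Boolean mask (label >= 5) cast to float, one column per sample
    return (labels >= 5).astype("float32").reshape(-1, 1)

# Example with a handful of digit labels
y = np.array([0, 3, 4, 5, 7, 9])
info = make_info_feature(y)
print(info.ravel())  # [0. 0. 0. 1. 1. 1.]
```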

Therefore, I want to build a model that takes both pieces of information as input: the image of the digit and the additional variable I just created.

I created the first two models, one for the image and one for the exogenous variable, as follows:

import tensorflow as tf
from tensorflow import keras
image_model = keras.models.Sequential()
#First conv layer :
image_model.add( keras.layers.Conv2D( 64, kernel_size=3,
                                      input_shape=(28, 28, 1) ) )
#Second conv layer :
image_model.add( keras.layers.Conv2D( 32, kernel_size=3, activation=keras.activations.relu ) )
#Flatten layer :
image_model.add( keras.layers.Flatten() )
# summary() prints the table itself and returns None, so no print() is needed
image_model.summary()
info_model = keras.models.Sequential()
info_model.add( keras.layers.Dense( 5, activation=keras.activations.relu, input_shape=(1,) ) )
info_model.summary()

Then I would like to concatenate the final layers of both models and add one more dense layer with a softmax activation to predict the class probabilities.

I know this is feasible with the Keras functional API, but how can it be done with tf.keras?
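One way to wire up the step described above while keeping the two Sequential models is to treat each of them as a layer: create one `tf.keras.Input` per branch, call each Sequential model on its input, and concatenate the results. This is a minimal sketch under these assumptions: the Sequential models start with an `Input` object instead of `input_shape=...` (newer Keras versions prefer this), the output has 10 units for the 10 MNIST digits, and the variable names are illustrative.

```python
import tensorflow as tf

# The two Sequential models from the question
image_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=3),
    tf.keras.layers.Conv2D(32, kernel_size=3, activation="relu"),
    tf.keras.layers.Flatten(),
])
info_model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(5, activation="relu"),
])

# New symbolic inputs, one per branch
image_input = tf.keras.Input(shape=(28, 28, 1))
info_input = tf.keras.Input(shape=(1,))

# A Sequential model is callable like a layer
image_features = image_model(image_input)
info_features = info_model(info_input)

# Concatenate both branches and add the softmax classifier (assumed 10 classes)
merged = tf.keras.layers.Concatenate()([image_features, info_features])
output = tf.keras.layers.Dense(10, activation="softmax")(merged)

combined = tf.keras.Model(inputs=[image_input, info_input], outputs=output)
combined.summary()
```

The nested Sequential models keep their own weights, so a model that was already trained on images alone can be reused as the image branch.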


You can easily use Keras’ functional API in TF (tested with TF 2.0):

import tensorflow as tf
# Image input
input_1 = tf.keras.layers.Input(shape=(28, 28, 1))
# First conv layer :
conv2d_1 = tf.keras.layers.Conv2D(64, kernel_size=3,
                                  activation=tf.keras.activations.relu)(input_1)
# Second conv layer :
conv2d_2 = tf.keras.layers.Conv2D(32, kernel_size=3,
                                  activation=tf.keras.activations.relu)(conv2d_1)
# Flatten layer :
flatten = tf.keras.layers.Flatten()(conv2d_2)
# The other input
input_2 = tf.keras.layers.Input(shape=(1,))
dense_2 = tf.keras.layers.Dense(5, activation=tf.keras.activations.relu)(input_2)
# Concatenate both branches
concat = tf.keras.layers.Concatenate()([flatten, dense_2])
# MNIST has 10 digit classes
n_classes = 10
# Output layer with softmax for class probabilities
output = tf.keras.layers.Dense(units=n_classes,
                               activation=tf.keras.activations.softmax)(concat)
full_model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output])
full_model.summary()

This gives you the model you are looking for.
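Because the model has two inputs, training data is passed to `fit` and `predict` as a list, in the same order as `inputs=[...]`. This is a minimal sketch that uses random placeholder arrays with MNIST shapes. In practice you would load the real data with `tf.keras.datasets.mnist.load_data()`, scale the pixels, and add a channel axis. The optimizer, loss, epoch count, and batch size are illustrative choices, not part of the answer above.

```python
import numpy as np
import tensorflow as tf

# Two-input model as in the answer (10 output classes for MNIST digits)
image_input = tf.keras.Input(shape=(28, 28, 1))
img = tf.keras.layers.Conv2D(64, kernel_size=3, activation="relu")(image_input)
img = tf.keras.layers.Conv2D(32, kernel_size=3, activation="relu")(img)
img = tf.keras.layers.Flatten()(img)

info_input = tf.keras.Input(shape=(1,))
inf = tf.keras.layers.Dense(5, activation="relu")(info_input)

merged = tf.keras.layers.Concatenate()([img, inf])
output = tf.keras.layers.Dense(10, activation="softmax")(merged)
model = tf.keras.Model(inputs=[image_input, info_input], outputs=output)

# Integer labels 0-9, so sparse categorical cross-entropy is used
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Placeholder data with MNIST shapes (not real digits)
rng = np.random.default_rng(0)
images = rng.random((64, 28, 28, 1), dtype=np.float32)
labels = rng.integers(0, 10, size=64)
info = (labels >= 5).astype("float32").reshape(-1, 1)

# Inputs are given as a list in the same order as inputs=[image_input, info_input]
history = model.fit([images, info], labels, epochs=1, batch_size=32, verbose=0)
probs = model.predict([images[:5], info[:5]], verbose=0)
print(probs.shape)  # (5, 10)
```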