That would be easiest with a custom training loop.
def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)
        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)
    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
    reconstruction_loss(loss)
Here, my data iterator cycles through all the color pictures, but each batch is converted to grayscale before being passed to the model. The RGB output of the model is then compared to the original RGB image. You will have to use the argument class_mode=None in flow_from_directory so the iterator yields only images, with no labels. I used tf.image.rgb_to_grayscale to make the conversion from RGB to grayscale.
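For reference, here is a minimal, standalone sketch (with a dummy random batch standing in for real images, which is my own placeholder) of what tf.image.rgb_to_grayscale does to the tensor shapes:

import tensorflow as tf

# A dummy batch of four 32x32 RGB images with values in [0, 1).
rgb_batch = tf.random.uniform((4, 32, 32, 3))

# rgb_to_grayscale keeps the batch and spatial dimensions and
# reduces the channel dimension from 3 to 1.
gray_batch = tf.image.rgb_to_grayscale(rgb_batch)
print(rgb_batch.shape)   # (4, 32, 32, 3)
print(gray_batch.shape)  # (4, 32, 32, 1)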
Full example:
import tensorflow as tf
import os

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

os.chdir(r'catsanddogs')

generator = tf.keras.preprocessing.image.ImageDataGenerator()
iterator = generator.flow_from_directory(
    target_size=(32, 32),
    directory='.',
    batch_size=4,
    class_mode=None)
# Encoder: compresses a 32x32 grayscale image into a 16-dimensional code.
encoder = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 1)),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Dense(16)
])

# Decoder: expands the code back into a 32x32 RGB image.
decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=[16]),
    tf.keras.layers.Dense(32 * 32 * 3),
    tf.keras.layers.Reshape([32, 32, 3])
])

autoencoder = tf.keras.Sequential([encoder, decoder])

loss_object = tf.losses.BinaryCrossentropy()
reconstruction_loss = tf.metrics.Mean(name='reconstruction_loss')
optimizer = tf.optimizers.Adam()
def reconstruct(colored_inputs):
    with tf.GradientTape() as tape:
        # Convert the RGB batch to grayscale for the model input;
        # the original RGB batch is kept as the reconstruction target.
        grayscale_inputs = tf.image.rgb_to_grayscale(colored_inputs)
        out = autoencoder(grayscale_inputs)
        loss = loss_object(colored_inputs, out)
    gradients = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
    reconstruction_loss(loss)
if __name__ == '__main__':
    template = 'Epoch {:2} Reconstruction Loss {:.4f}'
    for epoch in range(50):
        reconstruction_loss.reset_states()
        # flow_from_directory yields batches indefinitely, so stop
        # after one full pass over the dataset.
        for step, input_batches in enumerate(iterator):
            reconstruct(input_batches)
            if step + 1 >= len(iterator):
                break
        print(template.format(epoch + 1, reconstruction_loss.result()))
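Once training finishes, colorizing works the same way as the training step, minus the gradient update: convert a batch to grayscale and run it through the autoencoder. This is a minimal sketch that reuses the iterator and autoencoder defined above; the clipping step is my own addition, since the decoder's output is unbounded:

# Grab one RGB batch and feed its grayscale version to the trained model.
sample_batch = next(iterator)
gray_sample = tf.image.rgb_to_grayscale(sample_batch)
colorized = autoencoder(gray_sample)

# The decoder has no bounded activation, so clip to the input range
# before trying to display the result.
colorized = tf.clip_by_value(colorized, 0.0, 255.0)
print(colorized.shape)  # e.g. (4, 32, 32, 3)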