That would be easiest with a custom training loop.

def reconstruct(colored_inputs):
    """Run one optimisation step of the grayscale-to-colour autoencoder.

    The RGB batch is reduced to a single luminance channel and fed to
    ``autoencoder``. The RGB prediction is then scored against the original
    batch with ``loss_object``. The resulting gradients are applied to the
    autoencoder's trainable variables through ``optimizer``, and the batch
    loss is folded into the ``reconstruction_loss`` metric.

    Args:
        colored_inputs: a batch of RGB images. ``tf.image.rgb_to_grayscale``
            requires the last dimension to have size 3; presumably the batch
            is (batch, 32, 32, 3) from the directory iterator — confirm.
    """
    # The grayscale conversion touches no trainable variables, so it does not
    # need to be recorded by the tape.
    gray_batch = tf.image.rgb_to_grayscale(colored_inputs)
    with tf.GradientTape() as tape:
        predicted_rgb = autoencoder(gray_batch)
        batch_loss = loss_object(colored_inputs, predicted_rgb)

    params = autoencoder.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    # Accumulate into the running mean reported per epoch.
    reconstruction_loss(batch_loss)

Here, my data iterator cycles through all the color pictures, but each batch is converted to grayscale before being passed to the model. The model's RGB output is then compared with the original RGB image. You will have to pass the argument class_mode=None to flow_from_directory so that it yields only the images, without labels. I used tf.image.rgb_to_grayscale to convert the RGB images to grayscale.

Full example:

import os

import tensorflow as tf

# Enable on-demand GPU memory allocation for every visible GPU. Looping over
# the list (rather than indexing [0]) keeps the script runnable on CPU-only
# machines, where the list is empty.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# flow_from_directory looks for images inside subdirectories of `directory`.
os.chdir(r'catsanddogs')

# rescale maps 8-bit pixel values into [0, 1], the range BinaryCrossentropy
# expects for its targets. class_mode=None makes the iterator yield image
# batches only, without labels.
generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
iterator = generator.flow_from_directory(
    target_size=(32, 32),
    directory='.',
    batch_size=4,
    class_mode=None)

# Encoder: flattens a 32x32 single-channel (grayscale) image into a
# 16-dimensional code.
encoder = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 1)),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Dense(16)
])

# Decoder: expands the 16-dimensional code back to a 32x32x3 (RGB) image.
# The sigmoid keeps every output in (0, 1), i.e. a valid probability for
# BinaryCrossentropy, which defaults to from_logits=False.
decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=[16]),
    tf.keras.layers.Dense(32 * 32 * 3, activation='sigmoid'),
    tf.keras.layers.Reshape([32, 32, 3])
])


# Grayscale image in, RGB reconstruction out.
autoencoder = tf.keras.Sequential([encoder, decoder])

loss_object = tf.losses.BinaryCrossentropy()

# Running mean of the per-batch reconstruction loss.
reconstruction_loss = tf.metrics.Mean(name='reconstruction_loss')

optimizer = tf.optimizers.Adam()


def reconstruct(colored_inputs):
    """Run one optimisation step of the grayscale-to-colour autoencoder.

    The RGB batch is reduced to a single luminance channel and fed to
    ``autoencoder``. The RGB prediction is then scored against the original
    batch with ``loss_object``. The resulting gradients are applied to the
    autoencoder's trainable variables through ``optimizer``, and the batch
    loss is folded into the ``reconstruction_loss`` metric.

    Args:
        colored_inputs: a batch of RGB images. ``tf.image.rgb_to_grayscale``
            requires the last dimension to have size 3; presumably the batch
            is (batch, 32, 32, 3) from the directory iterator — confirm.
    """
    # The grayscale conversion touches no trainable variables, so it does not
    # need to be recorded by the tape.
    gray_batch = tf.image.rgb_to_grayscale(colored_inputs)
    with tf.GradientTape() as tape:
        predicted_rgb = autoencoder(gray_batch)
        batch_loss = loss_object(colored_inputs, predicted_rgb)

    params = autoencoder.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    # Accumulate into the running mean reported per epoch.
    reconstruction_loss(batch_loss)


if __name__ == '__main__':
    template = 'Epoch {:2} Reconstruction Loss {:.4f}'
    for epoch in range(50):
        reconstruction_loss.reset_states()
        # A Keras directory iterator cycles through the data forever, so
        # `for batch in iterator` would never finish an epoch. Take exactly
        # one pass worth of batches (len(iterator)) instead.
        for _ in range(len(iterator)):
            reconstruct(next(iterator))
        # float() turns the scalar metric tensor into a Python float so the
        # '{:.4f}' format spec is applied reliably across TF versions.
        print(template.format(epoch + 1, float(reconstruction_loss.result())))

CLICK HERE to find out more related problems solutions.

Leave a Comment

Your email address will not be published.

Scroll to Top