How to override the gradient vector calculation method for optimization algorithms in Keras/TensorFlow?

You should be able to do the following:

import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`).
            # Important: the loss must stay per-sample, i.e. have shape
            # (batch_size,) rather than being reduced to a scalar.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute per-sample gradients: for each trainable variable the
        # jacobian has shape (batch_size,) + variable_shape.
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)

        # Manipulate the per-sample gradients; `do_something_to` is a
        # placeholder for your own logic and must return a tensor with
        # the same shape as the corresponding variable.
        new_gradients = []
        for grad in gradients:
            new_grad = do_something_to(grad)
            new_gradients.append(new_grad)

        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
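
For context, here is a minimal usage sketch. The layer sizes, optimizer, and random data are assumptions, not part of the original answer, and it assumes a `do_something_to` like the one sketched further down is defined. The key detail is compiling with a loss whose reduction is NONE, which is intended to make `compiled_loss` return one loss value per sample (verify that it does return a (batch_size,)-shaped tensor in your Keras version), as required by the notes below:

import numpy as np
from tensorflow import keras

# Build any model as usual, just using CustomModel instead of keras.Model.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# Reduction.NONE keeps the loss per-sample, i.e. shape (batch_size,).
model.compile(
    optimizer="adam",
    loss=keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE),
    metrics=["mae"],
)

x = np.random.random((256, 32)).astype("float32")
y = np.random.random((256, 1)).astype("float32")
model.fit(x, y, epochs=1, batch_size=16)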

Some important notes: the loss returned by `compiled_loss` must not be averaged over the batch axis, i.e. I'm assuming it is a tensor of shape (batch_size,), not a scalar.
This makes the jacobian return gradients of shape (batch_size,) + variable_shape, that is, one gradient per batch element. You can manipulate these per-sample gradients however you want, but at some point you have to get rid of the extra batch axis (e.g. by averaging), so that each new_grad ends up with the same shape as its corresponding variable.
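
As a purely illustrative example (the clipping threshold and the averaging step are my assumptions, not part of the original answer), `do_something_to` could clip each per-sample gradient and then average over the batch axis:

import tensorflow as tf

# Hypothetical `do_something_to`: clip the L2 norm of every per-sample
# gradient, then average over the batch axis so that the result has the
# same shape as the corresponding variable.
def do_something_to(grad, clip_norm=1.0):
    batch_size = tf.shape(grad)[0]
    flat = tf.reshape(grad, [batch_size, -1])             # (batch_size, n)
    clipped = tf.clip_by_norm(flat, clip_norm, axes=[1])  # clip each row's norm
    clipped = tf.reshape(clipped, tf.shape(grad))         # back to original shape
    return tf.reduce_mean(clipped, axis=0)                # variable_shape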

Regarding your last comment: as I mentioned, the loss function indeed needs to return one loss value per data point, i.e. it must not average over the batch. However, that alone is not enough: if you handed this vector to tape.gradient, TensorFlow would simply sum its elements and differentiate that sum, because tape.gradient only works with scalar targets. This is why jacobian is necessary.
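
A tiny, self-contained toy example (values chosen arbitrarily) that shows the difference:

import tensorflow as tf

# For a non-scalar target, `tape.gradient` differentiates the *sum* of its
# elements, whereas `tape.jacobian` keeps one gradient row per element.
x = tf.Variable([1.0, 2.0, 3.0])
with tf.GradientTape(persistent=True) as tape:
    loss_vector = x * x                  # shape (3,), one "loss" per element

print(tape.gradient(loss_vector, x))     # [2. 4. 6.]  -- gradient of the sum
print(tape.jacobian(loss_vector, x))     # (3, 3) matrix with 2., 4., 6. on the diagonal
del tape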

Finally, be aware that jacobian can be very slow. In the worst case, the run time is multiplied by the batch size, because that many separate gradients have to be computed. However, these computations are parallelized to some degree, so the slowdown may be less severe in practice.
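
If the parallelized path ever gives you trouble (it can be memory-hungry, since it effectively materializes batch_size gradients at once), GradientTape.jacobian exposes two knobs worth knowing about; the values below are arbitrary examples:

# Sketch: trade speed for memory inside `train_step` by disabling the
# parallel-for implementation and bounding the loop's concurrency.
gradients = tape.jacobian(
    loss,
    trainable_vars,
    experimental_use_pfor=False,  # fall back to a tf.while_loop
    parallel_iterations=8,        # limit concurrent iterations of that loop
)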
