You should be able to do the following:

```python
import tensorflow as tf
from tensorflow import keras


def do_something_to(grad):
    # Placeholder for your own per-example manipulation. Whatever you do,
    # the result must drop the leading batch axis; here we simply average.
    return tf.reduce_mean(grad, axis=0)


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`).
            # It must be a per-example vector of shape (batch_size,).
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute per-example gradients: each has shape (batch_size,) + var.shape
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)
        new_gradients = []
        for grad in gradients:
            new_grad = do_something_to(grad)
            new_gradients.append(new_grad)
        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```

Some important notes: the `loss` returned by `compiled_loss` must not average over the batch axis; that is, I'm assuming it is a tensor of shape `(batch_size,)`, *not* a scalar.
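To see the difference between a per-example loss and the usual scalar loss, here is a small sketch with toy data (the values are made up for illustration). The functional `tf.keras.losses.mean_squared_error` only averages over the last axis, so it keeps one value per example. Inside `compile()`, one way to get such a loss in tf.keras 2.x is a loss object created with `reduction=tf.keras.losses.Reduction.NONE`, but please check the docs for your version, since the reduction API changed in Keras 3.

```python
import tensorflow as tf

# A toy batch: 4 examples, 3 outputs each (illustrative values).
y_true = tf.constant([[0.0, 1.0, 2.0],
                      [1.0, 1.0, 1.0],
                      [2.0, 0.0, 1.0],
                      [3.0, 3.0, 3.0]])
y_pred = tf.zeros_like(y_true)

# The functional form averages over the last axis only,
# so it keeps one loss value per example.
per_example_loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
print(per_example_loss.shape)  # (4,)

# Averaging over the batch axis is what produces the usual scalar loss,
# which is exactly what we must NOT do before calling `jacobian`.
scalar_loss = tf.reduce_mean(per_example_loss)
print(scalar_loss.shape)  # ()
```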

This causes `jacobian` to return gradients of shape `(batch_size,) + variable_shape`; that is, you now have per-example gradients. You can manipulate these gradients however you want, but at some point you must get rid of the additional batch axis (e.g. by averaging). In other words, `new_grad` should have the same shape as the corresponding variable.
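A minimal sketch of that shape bookkeeping on a single toy weight matrix `w` (all names and data here are illustrative, not part of the model above). `clip_and_average` is a hypothetical example of what `do_something_to` might do: it clips each example's gradient to an L2 norm of at most `max_norm` and then averages over the batch axis. Without any manipulation, plain averaging reproduces the ordinary gradient of the mean loss.

```python
import tensorflow as tf

tf.random.set_seed(0)
batch_size, n_in, n_out = 5, 3, 2
x = tf.random.normal((batch_size, n_in))
y = tf.random.normal((batch_size, n_out))
w = tf.Variable(tf.random.normal((n_in, n_out)))  # a single toy "layer"

with tf.GradientTape(persistent=True) as tape:
    y_pred = tf.matmul(x, w)
    # One loss value per example, shape (batch_size,).
    per_example_loss = tf.reduce_mean(tf.square(y - y_pred), axis=-1)
    mean_loss = tf.reduce_mean(per_example_loss)

# Per-example gradients: shape (batch_size,) + w.shape.
per_example_grads = tape.jacobian(per_example_loss, w)
print(per_example_grads.shape)  # (5, 3, 2)


def clip_and_average(grad, max_norm=1.0):
    """Hypothetical stand-in for `do_something_to`: clip each example's
    gradient to L2 norm <= max_norm, then average over the batch axis."""
    flat = tf.reshape(grad, [tf.shape(grad)[0], -1])
    norms = tf.norm(flat, axis=1)  # one norm per example
    scale = tf.minimum(1.0, max_norm / (norms + 1e-12))
    # Broadcast the per-example scale over the variable's dimensions.
    scale = tf.reshape(scale, [-1] + [1] * (len(grad.shape) - 1))
    return tf.reduce_mean(grad * scale, axis=0)


new_grad = clip_and_average(per_example_grads)
print(new_grad.shape)  # (3, 2), same as w

# Without clipping, averaging the per-example gradients reproduces the
# ordinary gradient of the mean loss.
plain_avg = tf.reduce_mean(per_example_grads, axis=0)
reference = tape.gradient(mean_loss, w)
del tape
```

Clipping per example (rather than clipping the already-averaged gradient) is the kind of manipulation that genuinely needs the extra batch axis, which is why this pattern is worth the cost of `jacobian`.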

Regarding your last comment: as I mentioned, the loss function indeed needs to return one loss per data point, i.e. it must not average over the batch. However, this alone is not enough: if you were to pass this vector to `tape.gradient`, it would simply sum up the loss values, because `gradient` treats a non-scalar target as the sum of its elements. This is why `jacobian` is necessary.
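A small sketch that checks this claim on hand-picked toy data (an illustrative linear model, not the model above): `tape.gradient` on the per-example loss vector gives the same result as summing the per-example gradients returned by `tape.jacobian`, so the per-example information is lost.

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 examples
w = tf.Variable([[0.5], [-1.0]])

with tf.GradientTape(persistent=True) as tape:
    # Squared output per example, shape (3,).
    per_example_loss = tf.squeeze(tf.square(tf.matmul(x, w)), axis=1)

# `gradient` on a non-scalar target differentiates the sum of its elements...
summed = tape.gradient(per_example_loss, w)            # shape (2, 1)
# ...which equals the sum of the per-example gradients from `jacobian`.
per_example = tape.jacobian(per_example_loss, w)       # shape (3, 2, 1)
del tape

# Both equal [[-53.], [-68.]] for this data.
print(tf.reduce_sum(per_example, axis=0).numpy().ravel())
print(summed.numpy().ravel())
```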

Finally, `jacobian` can be very slow. In the worst case, run time may be multiplied by the batch size, because it needs to compute that many separate gradients. However, this is done in parallel to some degree, so the slowdown might not be as bad.
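If you want to see this on your own hardware, here is a rough sketch comparing the two strategies `GradientTape.jacobian` offers: the default vectorized path (`experimental_use_pfor=True`) and a sequential `while_loop` with one backward pass per example (`experimental_use_pfor=False`, which in eager mode requires a persistent tape). The toy model and sizes are assumptions for illustration; timings vary a lot by hardware and include one-off tracing overhead, so treat them as indicative only.

```python
import time

import tensorflow as tf

tf.random.set_seed(0)
x = tf.random.normal((64, 10))                  # toy batch of 64 examples
w = tf.Variable(tf.random.normal((10, 1)))

# Persistent, because the sequential path needs it in eager mode
# and we call `jacobian` twice.
with tf.GradientTape(persistent=True) as tape:
    per_example_loss = tf.squeeze(tf.square(tf.matmul(x, w)), axis=1)  # (64,)

start = time.perf_counter()
vectorized = tape.jacobian(per_example_loss, w)  # default: vectorized via pfor
t_vec = time.perf_counter() - start

start = time.perf_counter()
# Sequential fallback: one gradient computation per example.
looped = tape.jacobian(per_example_loss, w, experimental_use_pfor=False)
t_loop = time.perf_counter() - start
del tape

print(f"pfor: {t_vec:.4f}s, while_loop: {t_loop:.4f}s")
```

The sequential path is mainly useful as a fallback when vectorization fails or uses too much memory; both paths should return the same per-example gradients.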
