Keras Custom Loss Function

I am working on a class-incremental classification approach, using a CNN as a feature extractor and a fully connected block for classification.

First, I fine-tune a pre-trained VGG network for each new task. Whenever the network learns a task, I keep a few exemplars per class so that it does not forget them when new classes become available.

When new classes become available, I compute the network's outputs for every instance in the exemplar sets. Then, by appending zeros to those outputs for the old-class exemplars, and by appending a one-hot label for each new class to the outputs of the new-class instances, I obtain my new targets. For example, if 3 new classes are introduced:

Old class target: [0.1, 0.05, 0.79, ..., 0, 0, 0]

New class target: [0.1, 0.09, 0.3, 0.4, ..., 1, 0, 0] (the last outputs correspond to the new classes).
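Concretely, building such targets looks roughly like this (a simplified NumPy sketch; the array names and sizes are just for illustration, not my actual code):

    import numpy as np

    # Simplified sketch: 3 old classes, 3 new classes.
    n_new_classes = 3

    # Stored soft outputs of the old network for an old-class exemplar.
    old_soft_outputs = np.array([0.1, 0.05, 0.79])

    # Old-class target: stored outputs padded with zeros for the new classes.
    old_target = np.concatenate([old_soft_outputs, np.zeros(n_new_classes)])
    # -> [0.1, 0.05, 0.79, 0., 0., 0.]

    # New-class target: current old-class outputs followed by a one-hot
    # label for the new class this sample belongs to.
    current_old_outputs = np.array([0.1, 0.09, 0.3])
    one_hot = np.zeros(n_new_classes)
    one_hot[0] = 1.0  # the sample belongs to the first new class
    new_target = np.concatenate([current_old_outputs, one_hot])
    # -> [0.1, 0.09, 0.3, 1., 0., 0.]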

My question is: how can I replace the loss function with a custom one in order to train on the new classes? The loss function I want to implement is defined as:

[image: loss function — the total loss is the sum of a distillation term and a classification term]

where the distillation loss corresponds to the outputs of the old classes (to avoid forgetting), and the classification loss corresponds to the new classes.

If you could provide sample code for changing the loss function in Keras, that would be great.

Thanks!!!!!

Tags: deep-learning, computer-vision, keras, conv-neural-network, loss-function
1 answer

All you have to do is define a function for it, using Keras backend functions for the calculations. The function must take the true values and the model's predicted values.

Now, since I'm not sure what g, q, x, and y are in your function, I’ll just create a basic example here without worrying about what this means or whether it is really a useful function:

    import keras.backend as K

    # A custom loss receives the true values and the predicted values
    # and returns a tensor of per-sample losses.
    def customLoss(yTrue, yPred):
        return K.sum(K.log(yTrue) - K.log(yPred))

All backend functions can be viewed here: https://keras.io/backend/#backend-functions

After that, compile your model using this function instead of the usual one:

    model.compile(loss=customLoss, optimizer=...)
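If you want a loss closer to the one you describe — a distillation term on the old-class outputs plus a classification term on the new ones — a rough sketch could look like the following. Note that `n_old` and `distillation_weight` are placeholders you would set yourself, and the exact distillation formula depends on the equation in your image, which I can't see; binary cross-entropy on the old slice is just one common choice:

    import keras.backend as K

    n_old = 10  # placeholder: number of old-class outputs
    distillation_weight = 1.0  # placeholder: balance between the two terms

    def incrementalLoss(yTrue, yPred):
        # Distillation term: keep the old-class outputs close to the
        # stored soft targets (cross-entropy on the old slice).
        dist = K.binary_crossentropy(yTrue[:, :n_old], yPred[:, :n_old])

        # Classification term: standard cross-entropy on the new classes.
        clf = K.categorical_crossentropy(yTrue[:, n_old:], yPred[:, n_old:])

        return distillation_weight * K.sum(dist, axis=-1) + clf

    model.compile(loss=incrementalLoss, optimizer='adam')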
