How to set parameters in keras to be non-trainable?
You can simply assign a boolean value to the layer's trainable property:
model.layers[n].trainable = False
You can visualize which layer is trainable:
for layer in model.layers:
    print(layer.name, layer.trainable)
You can also pass it when defining the layer:
frozen_layer = Dense(32, trainable=False)
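As a minimal sketch of the constructor route, the snippet below builds a small model and checks where the frozen layer's weights end up. The layer names ("frozen_dense", "head") and the sizes are made up for illustration.

```python
from tensorflow import keras

# Hypothetical two-layer model; names and sizes are arbitrary.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(32, trainable=False, name="frozen_dense"),
    keras.layers.Dense(1, name="head"),
])

frozen = model.get_layer("frozen_dense")

# A frozen Dense layer reports its kernel and bias as non-trainable.
print(len(frozen.trainable_weights), len(frozen.non_trainable_weights))

# Only the head's kernel and bias remain trainable in the whole model.
print(len(model.trainable_weights))
```

Checking trainable_weights is a quick way to confirm the flag was picked up before you start training.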
From Keras documentation:
To "freeze" a layer means to exclude it from training, i.e. its weights will never be updated. This is useful in the context of fine-tuning a model, or using fixed embeddings for a text input.
You can pass a trainable argument (boolean) to a layer constructor to set a layer to be non-trainable. Additionally, you can set the trainable property of a layer to True or False after instantiation. For this to take effect, you will need to call compile() on your model after modifying the trainable property.
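To illustrate the quoted passage, here is a rough sketch that freezes a layer after instantiation, compiles, trains briefly on random data, and compares the weights before and after. The model, the layer names and the data are assumptions made up for the example.

```python
import numpy as np
from tensorflow import keras

# Hypothetical model; layer names and sizes are arbitrary.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu", name="base"),
    keras.layers.Dense(1, name="head"),
])

base = model.get_layer("base")
head = model.get_layer("head")

base.trainable = False                       # freeze after instantiation
model.compile(optimizer="sgd", loss="mse")   # compile after changing trainable

base_before = base.get_weights()
head_before = head.get_weights()

# Random data, only to drive a few gradient updates.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=2, verbose=0)

base_unchanged = all(
    np.array_equal(b, a) for b, a in zip(base_before, base.get_weights())
)
head_changed = any(
    not np.array_equal(b, a) for b, a in zip(head_before, head.get_weights())
)
print(base_unchanged, head_changed)
```

Comparing weights this way is a simple sanity check that the frozen layer is really excluded from the updates while the rest of the model still learns.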
There is a typo in the word "trainble" (missing an "a"). Sadly, Keras doesn't warn me that the model doesn't have the property "trainble". The question could be closed.
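The reason there is no warning is plain Python behaviour: assigning to an attribute name that does not exist simply creates it. The sketch below shows this with FakeLayer, a hypothetical stand-in class (not part of Keras), and set_existing_attr, a hypothetical helper that refuses to create new attributes.

```python
class FakeLayer:
    """Hypothetical stand-in for a Keras layer; not part of Keras."""

    def __init__(self):
        self.trainable = True


def set_existing_attr(obj, name, value):
    """Hypothetical helper: only set attributes that already exist."""
    if not hasattr(obj, name):
        raise AttributeError(f"{type(obj).__name__} has no attribute {name!r}")
    setattr(obj, name, value)


layer = FakeLayer()
layer.trainble = False        # typo: silently creates a new attribute
print(layer.trainable)        # the real flag was never touched

try:
    set_existing_attr(layer, "trainbel", False)   # misspelled on purpose
    typo_caught = False
except AttributeError:
    typo_caught = True

set_existing_attr(layer, "trainable", False)      # correct name works
print(typo_caught, layer.trainable)
```

A guard like this is optional; printing layer.trainable after setting it catches the same mistake.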
Although the original question is solved by fixing a typo, let me add some information on trainable state in Keras.
Modern Keras contains the following facilities to view and manipulate trainable state:
tf.keras.Layer._get_trainable_state() method - returns a dictionary whose keys are model components and whose values are booleans. Note that tf.keras.Model is also a tf.keras.Layer. The leading underscore marks it as private, so it may not be available in every Keras version.
tf.keras.Layer.trainable property - manipulates the trainable state of individual layers.
So the typical actions look like the following:
# Print current trainable map:
print(model._get_trainable_state())
# Set every layer to be non-trainable:
for k, v in model._get_trainable_state().items():
    k.trainable = False
# Don't forget to re-compile the model
model.compile(...)
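If the private method is not available in your version, the same effect can be sketched with the public trainable property alone: setting it on the model propagates to every layer it contains. The model below is a made-up example.

```python
from tensorflow import keras

# Hypothetical model; sizes are arbitrary.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])

# Setting trainable on the model applies to all of its layers.
model.trainable = False

# Public-API view of the trainable map, keyed by layer name.
states = {layer.name: layer.trainable for layer in model.layers}
print(states)

# Compile after changing the trainable state.
model.compile(optimizer="adam", loss="mse")
```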
Change the last 3 lines in your code:
last_few_layers = 20  # number of final layers left trainable; all earlier layers are frozen
self.domain_regressor = Model(img_inputs, domain_label)
for layer in model.layers[:-last_few_layers]:
    layer.trainable = False
self.domain_regressor.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])
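Note that the slice [:-n] selects everything except the last n layers, so the loop freezes the early layers and leaves the last n trainable. A self-contained sketch on a made-up six-layer model (trainable_tail is a hypothetical name) makes this visible:

```python
from tensorflow import keras

# Hypothetical stack: five hidden Dense layers plus a sigmoid output.
model = keras.Sequential(
    [keras.Input(shape=(16,))]
    + [keras.layers.Dense(16, activation="relu") for _ in range(5)]
    + [keras.layers.Dense(1, activation="sigmoid")]
)

trainable_tail = 2  # number of final layers that stay trainable

# Freeze every layer except the last `trainable_tail` ones.
for layer in model.layers[:-trainable_tail]:
    layer.trainable = False

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

flags = [layer.trainable for layer in model.layers]
print(flags)
```

To freeze the last n layers instead, the slice would be model.layers[-n:].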