@@ -196,11 +196,13 @@ class TorchModel(BaseModel, ExtractionMixin, OptimalBatchSizeMixin, Visualizatio
196
196
trainable : sequence, optional
197
197
Names of model parts to train. Should be a subset of names in `order` and can be used to freeze parameters.
198
198
199
- init_weights : callable, 'best_practice_resnet', or None
199
+ init_weights : callable, 'best_practice_resnet', tuple, sequence of them or None
200
200
Model weights initialization.
201
201
If None, then default initialization is used.
202
202
If 'best_practice_resnet', then common used non-default initialization is used.
203
203
If callable, then callable applied to each layer.
204
+ If tuple, then the first element should be of the types above, and the second defines a model part to apply on.
205
+ If sequence, then each element should be of the types defined above: applies all init functions sequentially.
204
206
205
207
Examples:
206
208
@@ -212,6 +214,7 @@ def callable_init(module): # example of a callable for init
212
214
nn.kaiming_normal_(module.weight)
213
215
214
216
config = {'init_weights': callable_init}
217
+ - ``{'init_weights': ('best_practice_resnet', 'body')}`` # applies only at `body` module
215
218
216
219
217
220
# Shapes: optional
@@ -854,7 +857,8 @@ def build_model(self, inputs=None):
854
857
inputs = self .make_placeholder_data (to_device = True )
855
858
856
859
if 'model' not in self .config :
857
- self .model = Network (inputs = inputs , config = self .config , device = self .device )
860
+ with torch .no_grad ():
861
+ self .model = Network (inputs = inputs , config = self .config , device = self .device )
858
862
else :
859
863
self .model = self .config ['model' ]
860
864
@@ -902,12 +906,20 @@ def initialize_weights(self):
902
906
# Parse model weights initialization
903
907
init_weights = init_weights if isinstance (init_weights , list ) else [init_weights ]
904
908
905
- for init_weights_function in init_weights :
909
+ for init_weights_ in init_weights :
910
+ if isinstance (init_weights_ , tuple ) and len (init_weights_ ) == 2 :
911
+ init_weights_function , init_weights_module = init_weights_
912
+ else :
913
+ init_weights_function , init_weights_module = init_weights_ , None
914
+
906
915
if init_weights_function in {'resnet' , 'classic' }:
907
916
init_weights_function = best_practice_resnet_init
908
917
909
918
# Actual weights initialization
910
- self .model .apply (init_weights_function )
919
+ if init_weights_module is None :
920
+ self .model .apply (init_weights_function )
921
+ else :
922
+ getattr (self .model , init_weights_module ).apply (init_weights_function )
911
923
912
924
913
925
# Transfer to/from device(s)
@@ -1010,6 +1022,10 @@ def train(self, inputs, targets, outputs=None, mode='train', lock=True, profile=
1010
1022
with the same keys and requested tensors as values.
1011
1023
lock : bool
1012
1024
If True, then model, loss and gradient update operations are locked, thus allowing for multithreading.
1025
+ mode : None, str or callable
1026
+ If None, then does nothing.
1027
+ If str, then identifies mode to put the model in: one of ``'train'`` or ``'eval'``.
1028
+ If callable, then applied to the model directly.
1013
1029
sync_frequency : int, bool or None
1014
1030
If int, then how often to apply accumulated gradients to the weights.
1015
1031
If True, then value from config is used.
@@ -1336,6 +1352,10 @@ def predict(self, inputs, targets=None, outputs=None, lock=True, microbatch_size
1336
1352
amp : None or bool
1337
1353
If None, then use amp setting from config.
1338
1354
If bool, then overrides the amp setting for prediction.
1355
+ mode : None, str or callable
1356
+ If None, then does nothing.
1357
+ If str, then identifies mode to put the model in: one of ``'train'`` or ``'eval'``.
1358
+ If callable, then applied to the model directly.
1339
1359
no_grad : bool
1340
1360
Whether to disable gradient computation during model evaluation.
1341
1361
transfer_from_device : bool
@@ -1476,11 +1496,15 @@ def __call__(self, inputs, targets=None, outputs='predictions', lock=True,
1476
1496
1477
1497
# Common utilities for train and predict
1478
1498
def set_model_mode (self , mode ):
1479
- """ Set model mode to either train or eval. """
1499
+ """ Set model mode to either train or eval. If provided with a callable, applies it to the model directly. """
1480
1500
if mode in {'train' , 'training' }:
1481
1501
self .model .train ()
1482
1502
elif mode in {'eval' , 'predict' , 'inference' }:
1483
1503
self .model .eval ()
1504
+ elif mode is None :
1505
+ pass
1506
+ elif callable (mode ):
1507
+ self .model .apply (mode )
1484
1508
else :
1485
1509
raise ValueError (f'Unknown model mode={ mode } ' )
1486
1510
0 commit comments