Skip to content

Commit 0a81d5b

Browse files
Merge pull request #760 from analysiscenter/improvements
Improvements
2 parents 3446ceb + 9bb89eb commit 0a81d5b

File tree

3 files changed

+95
-12
lines changed

3 files changed

+95
-12
lines changed

batchflow/models/torch/base.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,13 @@ class TorchModel(BaseModel, ExtractionMixin, OptimalBatchSizeMixin, Visualizatio
196196
trainable : sequence, optional
197197
Names of model parts to train. Should be a subset of names in `order` and can be used to freeze parameters.
198198
199-
init_weights : callable, 'best_practice_resnet', or None
199+
init_weights : callable, 'best_practice_resnet', tuple, sequence of them or None
200200
Model weights initialization.
201201
If None, then default initialization is used.
202202
If 'best_practice_resnet', then common used non-default initialization is used.
203203
If callable, then callable applied to each layer.
204+
If tuple, then the first element should be of the types above, and the second defines a model part to apply on.
205+
If sequence, then each element should be of the types defined above: applies all init functions sequentially.
204206
205207
Examples:
206208
@@ -212,6 +214,7 @@ def callable_init(module): # example of a callable for init
212214
nn.kaiming_normal_(module.weight)
213215
214216
config = {'init_weights': callable_init}
217+
- ``{'init_weights': ('best_practice_resnet', 'body')}`` # applies only at `body` module
215218
216219
217220
# Shapes: optional
@@ -854,7 +857,8 @@ def build_model(self, inputs=None):
854857
inputs = self.make_placeholder_data(to_device=True)
855858

856859
if 'model' not in self.config:
857-
self.model = Network(inputs=inputs, config=self.config, device=self.device)
860+
with torch.no_grad():
861+
self.model = Network(inputs=inputs, config=self.config, device=self.device)
858862
else:
859863
self.model = self.config['model']
860864

@@ -902,12 +906,20 @@ def initialize_weights(self):
902906
# Parse model weights initialization
903907
init_weights = init_weights if isinstance(init_weights, list) else [init_weights]
904908

905-
for init_weights_function in init_weights:
909+
for init_weights_ in init_weights:
910+
if isinstance(init_weights_, tuple) and len(init_weights_) == 2:
911+
init_weights_function, init_weights_module = init_weights_
912+
else:
913+
init_weights_function, init_weights_module = init_weights_, None
914+
906915
if init_weights_function in {'resnet', 'classic'}:
907916
init_weights_function = best_practice_resnet_init
908917

909918
# Actual weights initialization
910-
self.model.apply(init_weights_function)
919+
if init_weights_module is None:
920+
self.model.apply(init_weights_function)
921+
else:
922+
getattr(self.model, init_weights_module).apply(init_weights_function)
911923

912924

913925
# Transfer to/from device(s)
@@ -1010,6 +1022,10 @@ def train(self, inputs, targets, outputs=None, mode='train', lock=True, profile=
10101022
with the same keys and requested tensors as values.
10111023
lock : bool
10121024
If True, then model, loss and gradient update operations are locked, thus allowing for multithreading.
1025+
mode : None, str or callable
1026+
If None, then does nothing.
1027+
If str, then identifies mode to put the model in: one of ``'train'`` or ``'eval'``.
1028+
If callable, then applied to the model directly.
10131029
sync_frequency : int, bool or None
10141030
If int, then how often to apply accumulated gradients to the weights.
10151031
If True, then value from config is used.
@@ -1336,6 +1352,10 @@ def predict(self, inputs, targets=None, outputs=None, lock=True, microbatch_size
13361352
amp : None or bool
13371353
If None, then use amp setting from config.
13381354
If bool, then overrides the amp setting for prediction.
1355+
mode : None, str or callable
1356+
If None, then does nothing.
1357+
If str, then identifies mode to put the model in: one of ``'train'`` or ``'eval'``.
1358+
If callable, then applied to the model directly.
13391359
no_grad : bool
13401360
Whether to disable gradient computation during model evaluation.
13411361
transfer_from_device : bool
@@ -1476,11 +1496,15 @@ def __call__(self, inputs, targets=None, outputs='predictions', lock=True,
14761496

14771497
# Common utilities for train and predict
14781498
def set_model_mode(self, mode):
1479-
""" Set model mode to either train or eval. """
1499+
""" Set model mode to either train or eval. If provided with a callable, applies it to the model directly. """
14801500
if mode in {'train', 'training'}:
14811501
self.model.train()
14821502
elif mode in {'eval', 'predict', 'inference'}:
14831503
self.model.eval()
1504+
elif mode is None:
1505+
pass
1506+
elif callable(mode):
1507+
self.model.apply(mode)
14841508
else:
14851509
raise ValueError(f'Unknown model mode={mode}')
14861510

batchflow/named_expr.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,15 @@ def get(self, **kwargs):
343343
a = eval_expr(self.a, _call=False, **kwargs)
344344
else:
345345
a = eval_expr(self.a, **kwargs)
346-
b = eval_expr(self.b, **kwargs)
347-
c = eval_expr(self.c, **kwargs)
346+
348347
if self.op in UNARY_OPS:
349348
return OPERATIONS[self.op](a)
349+
350+
b = eval_expr(self.b, **kwargs)
350351
if self.op in BINARY_OPS:
351352
return OPERATIONS[self.op](a, b)
353+
354+
c = eval_expr(self.c, **kwargs)
352355
return OPERATIONS[self.op](a, b, c)
353356

354357
def assign(self, value, **kwargs):
@@ -401,6 +404,58 @@ def __repr__(self):
401404
return 'Unknown expression'
402405

403406

407+
class IF(NamedExpression):
408+
""" Select either ``true`` or ``false``, based on ``condition``.
409+
Useful for simple variables that change along the run of a pipeline.
410+
411+
Examples
412+
--------
413+
Select model mode based on the current pipeline iteration::
414+
mode = IF(condition=I.current<450, true='train', false='eval')
415+
416+
Train the last 20% with larger batch size::
417+
batch_size = IF(condition=I.ratio > 0.8, true=256, false=128)
418+
419+
Notes
420+
-----
421+
An alternative to this named expression is to use ``F``::
422+
def select_batch_size(ratio):
423+
return 256 if ratio > 0.8 else 128
424+
425+
batch_size = F(select_batch_size)(ratio=I.ratio)
426+
427+
Or, with a lambda::
428+
batch_size = F(lambda ratio: 256 if ratio > 0.8 else 128)(ratio=I.ratio)
429+
430+
``F`` is recommended where more flexibility is needed, and ``IF`` can be used for simple binary choices.
431+
"""
432+
def __init__(self, condition, true, false, mode='w', **kwargs):
433+
super().__init__('#!__if__', mode=mode, **kwargs)
434+
self.condition = condition
435+
self.true = true
436+
self.false = false
437+
438+
def get(self, **kwargs):
439+
""" Select based on condition. """
440+
condition = eval_expr(self.condition, **kwargs)
441+
442+
if bool(condition):
443+
return eval_expr(self.true, **kwargs)
444+
return eval_expr(self.false, **kwargs)
445+
446+
447+
def assign(self, value, **kwargs):
448+
""" Assign a value to a named expression, based on condition. """
449+
_, kwargs = self._get_params(**kwargs)
450+
451+
condition = eval_expr(self.condition, **kwargs)
452+
453+
if bool(condition):
454+
self.true.assign(value, **kwargs)
455+
else:
456+
self.false.assign(value, **kwargs)
457+
458+
404459
class B(NamedExpression):
405460
""" Batch component or attribute name
406461

batchflow/notifier.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class Notifier:
9393
If str, then either registered monitor identifiers or names of pipeline variables.
9494
Named expressions are evaluated with the pipeline.
9595
If callable, then it is used to retrieve the container with data.
96+
Should accept variable named arguments in the signature and may return ``None`` value to disable its plotting.
9697
If sequence, then it is used as the container with data.
9798
If dict, then 'source' key should be one of the above to identify container.
9899
Other available keys:
@@ -181,7 +182,7 @@ def __init__(self, bar='a', disable=False, frequency=1, monitors=None, graphs=No
181182
elif isinstance(source, str):
182183
container['name'] = source
183184
elif callable(source):
184-
container['name'] = '<unknown_callable>'
185+
container['name'] = source.__name__
185186
else:
186187
container['name'] = '<unknown_container>'
187188

@@ -364,7 +365,7 @@ def update_data(self, pipeline=None, batch=None):
364365
container['data'] = value
365366

366367
elif callable(source):
367-
container['data'] = source()
368+
container['data'] = source(container=container, notifier=self, pipeline=pipeline, batch=batch)
368369

369370
else:
370371
raise TypeError(f'Unknown type of `source`, {type(source)}!')
@@ -382,7 +383,7 @@ def make_plotter(self, num_graphs=None, layout='horizontal', figsize=None, ncols
382383
""" Make canvas for plotting graphs. """
383384
from .plotter import plot
384385
if num_graphs is None:
385-
num_graphs = len(self.data_containers)
386+
num_graphs = sum(container['data'] is not None for container in self.data_containers)
386387

387388
if ncols is None and nrows is None:
388389
if layout in ['h', 'horizontal']:
@@ -430,7 +431,8 @@ def update_plot(self, index=0, add_suptitle=False, savepath=None, clear_display=
430431
self.plotter.config['suptitle'] = self.bar.format_meter(**fmt)
431432
self.plotter.annotate()
432433

433-
for i, container in enumerate(self.data_containers):
434+
data_containers = [container for container in self.data_containers if container['data'] is not None]
435+
for i, container in enumerate(data_containers):
434436
if i >= index:
435437
subplot_index = i - index
436438
subplot_config = plot_config.maybe_index(subplot_index)
@@ -461,7 +463,7 @@ def update_subplot(self, container, index, **kwargs):
461463
plot_config = container.get('plot_config', {})
462464
plot_config = {**plot_config, **kwargs}
463465

464-
x = np.arange(len(data))
466+
x = np.arange(len(data)) if hasattr(data, '__len__') else None
465467
y = data
466468
if self.slice not in [None, slice(None)]:
467469
x = np.array(x)[self.slice]
@@ -471,6 +473,8 @@ def update_subplot(self, container, index, **kwargs):
471473
plot_function(ax=subplot.ax, index=index, x=x, y=y, container=container, notifier=self, **plot_config)
472474
elif isinstance(source, ResourceMonitor):
473475
source.plot(plotter=self.plotter, positions=index, **plot_config)
476+
elif data is None:
477+
pass
474478
else:
475479
source_defaults = {'title': name}
476480
if isinstance(data, (tuple, list)) or (isinstance(data, np.ndarray) and data.ndim == 1):

0 commit comments

Comments
 (0)