Skip to content

Commit 24085dd

Browse files
fix PR comments
1 parent 92b42ff commit 24085dd

File tree

1 file changed

+56
-40
lines changed

1 file changed

+56
-40
lines changed

batchflow/models/torch/base.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from torch import nn
1717
from torch.optim.swa_utils import AveragedModel, SWALR
1818

19-
import openvino as ov
20-
import shelve
2119

2220
from sklearn.decomposition import PCA
2321

@@ -1703,8 +1701,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17031701
path, pickle_module=pickle_module, **kwargs)
17041702

17051703
elif use_openvino:
1706-
if batch_size is None:
1707-
raise ValueError('Specify valid `batch_size`, used for model inference!')
1704+
import openvino as ov
17081705

17091706
path_openvino = path_openvino or (path + '_openvino')
17101707
if os.path.splitext(path_openvino)[-1] == '':
@@ -1722,16 +1719,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17221719
# Save the rest of parameters
17231720
preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
17241721
preserved_dict = {item: getattr(self, item) for item in preserved}
1725-
out_path_params = f'{os.path.splitext(path_openvino)[0]}_bf_params_db'
1726-
1727-
with shelve.open(out_path_params) as params_db:
1728-
params_db.update(preserved_dict)
1722+
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
1723+
path, pickle_module=pickle_module, **kwargs)
17291724

17301725
else:
17311726
torch.save({item: getattr(self, item) for item in self.PRESERVE},
17321727
path, pickle_module=pickle_module, **kwargs)
17331728

1734-
def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
1729+
def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
17351730
""" Load a torch model from a file.
17361731
17371732
If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
@@ -1741,8 +1736,6 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
17411736
----------
17421737
file : str, PathLike, io.Bytes
17431738
a file where a model is stored.
1744-
is_openvino : bool, default False
1745-
Whether the load file as openvino model instance.
17461739
make_infrastructure : bool
17471740
Whether to re-create model loss, optimizer, scaler and decay.
17481741
mode : str
@@ -1752,39 +1745,40 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
17521745
kwargs : dict
17531746
Other keyword arguments, passed directly to :func:`torch.save`.
17541747
"""
1755-
self._parse_devices()
1748+
model_load_kwargs = kwargs.pop('model_load_kwargs', {})
17561749

1757-
if is_openvino:
1758-
device = kwargs.pop('device', None) or self.device or 'CPU'
1759-
self.device = device.lower()
1750+
device = kwargs.pop('device', None)
17601751

1761-
model = OVModel(model_path=file, device=device, **kwargs)
1762-
self.model = model
1752+
if device is not None:
1753+
self.device = device
17631754

1764-
# Load params
1765-
out_path_params = f'{os.path.splitext(file)[0]}_bf_params_db'
1766-
with shelve.open(out_path_params) as params_db:
1767-
params = {**params_db}
1755+
if (self.device == 'cpu') or ((not isinstance(self.device, str)) and (self.device.type == 'cpu')):
1756+
self.amp = False
1757+
else:
1758+
self._parse_devices()
17681759

1769-
for key, value in params.items():
1770-
setattr(self, key, value)
1760+
kwargs['map_location'] = self.device
17711761

1772-
self._loaded_from_openvino = True
1773-
self.disable_training = True
1774-
else:
1775-
kwargs['map_location'] = self.device if self.device else 'cpu'
1762+
# Load items from disk storage and set them as insance attributes
1763+
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
17761764

1777-
# Load items from disk storage and set them as insance attributes
1778-
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
1765+
# `load_config` is a reference to `self.external_config` used to update `config`
1766+
# It is required since `self.external_config` may be overwritten in the cycle below
1767+
load_config = self.external_config
17791768

1780-
# `load_config` is a reference to `self.external_config` used to update `config`
1781-
# It is required since `self.external_config` may be overwritten in the cycle below
1782-
load_config = self.external_config
1769+
for key, value in checkpoint.items():
1770+
setattr(self, key, value)
1771+
self.config = self.config + load_config
17831772

1784-
for key, value in checkpoint.items():
1785-
setattr(self, key, value)
1786-
self.config = self.config + load_config
1773+
if 'openvino' in checkpoint:
1774+
# Load openvino model
1775+
model = OVModel(model_path=checkpoint['path_openvino'], **model_load_kwargs)
1776+
self.model = model
17871777

1778+
self._loaded_from_openvino = True
1779+
self.disable_training = True
1780+
1781+
else:
17881782
# Load model from onnx, if needed
17891783
if 'onnx' in checkpoint:
17901784
try:
@@ -1957,25 +1951,47 @@ def reduce_channels(array, normalize=True, n_components=3):
19571951
return compressed_array, explained_variance_ratio
19581952

19591953
class OVModel:
1960-
def __init__(self, model_path, core_config=None, device='CPU', compile_config=None):
1954+
""" Class-wrapper for openvino models to interact with them through :class:`~.TorchModel` interface.
1955+
1956+
Note, openvino models are loaded on 'cpu' only.
1957+
1958+
Parameters
1959+
----------
1960+
model_path : str
1961+
Path to compiled openvino model.
1962+
core_config : tuple or dict, optional
1963+
Openvino core properties.
1964+
If you want set properties globally provide them as tuple: `('CPU', {name: value})`.
1965+
For local properties just provide `{name: value}` dict.
1966+
For more, read the documentation:
1967+
https://docs.openvino.ai/2023.3/openvino_docs_OV_UG_query_api.html#setting-properties-globally
1968+
compile_config : dict, optional
1969+
Openvino model compilation config.
1970+
"""
1971+
def __init__(self, model_path, core_config=None, compile_config=None):
1972+
import openvino as ov
1973+
19611974
core = ov.Core()
19621975

19631976
if core_config is not None:
1964-
for name, kwargs_ in core_config.items():
1965-
core.set_property(name, kwargs_)
1977+
if isinstance(core_config, tuple):
1978+
core.set_property(core_config[0], core_config[1])
1979+
else:
1980+
core.set_property(core_config)
19661981

19671982
self.model = core.read_model(model=model_path)
19681983

19691984
if compile_config is None:
19701985
compile_config = {}
1971-
self.model = core.compile_model(self.model, device, config=compile_config)
1986+
1987+
self.model = core.compile_model(self.model, 'CPU', config=compile_config)
19721988

19731989
def eval(self):
19741990
""" Placeholder for compatibility with :class:`~TorchModel` methods."""
19751991
pass
19761992

19771993
def __call__(self, input_tensor):
1978-
""" Evaluate model on provided data. """
1994+
""" Evaluate model on the provided data. """
19791995
results = self.model(input_tensor)
19801996

19811997
results = torch.from_numpy(results[self.model.output(0)])

0 commit comments

Comments
 (0)