288 changes: 214 additions & 74 deletions batchflow/models/torch/base.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion batchflow/models/torch/base_mixins.py
@@ -5,9 +5,12 @@
 import numpy as np
 import torch
 
-from ...plotter import plot
 from ...decorators import deprecated
 
+from ...utils_import import try_import
+plot = try_import(module='...plotter', package=__name__, attribute='plot',
+                  help='Try `pip install batchflow[image]`!')
+
 # Also imports `tensorboard`, if necessary
 
 
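For context, this guard keeps the plotting stack optional: `plot` resolves normally when the image extras are installed and otherwise fails lazily with an install hint. A minimal sketch of the general pattern, illustrative only and not the actual `batchflow.utils_import.try_import` implementation (the absolute `batchflow.plotter` path is an assumption):

# Illustrative sketch of the optional-import pattern, not batchflow's helper.
try:
    from batchflow.plotter import plot
except ImportError:
    def plot(*args, **kwargs):
        # Raised only when plotting is actually requested.
        raise ImportError('Plotting requires optional dependencies. '
                          'Try `pip install batchflow[image]`!')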
125 changes: 125 additions & 0 deletions batchflow/plotter/morphology.py
@@ -0,0 +1,125 @@
"""Morphological operations implemented with numba to replace cv2 dependency."""

import numpy as np
from numba import njit, prange


@njit
def dilate(image, kernel, iterations=1):
    """Dilate an image using a structuring element.

    Parameters
    ----------
    image : numpy.ndarray
        Input image to dilate.
    kernel : numpy.ndarray
        Structuring element (kernel) for dilation. Should contain 1s where
        the structuring element is active and 0s elsewhere.
    iterations : int, optional
        Number of times to apply the dilation. Default is 1.

    Returns
    -------
    numpy.ndarray
        Dilated image with the same shape and dtype as input.

    """
    result = image.copy()

    for _ in range(iterations):
        result = _single_dilate(result, kernel)

    return result

@njit
def erode(image, kernel, iterations=1):
    """Erode an image using a structuring element.

    Parameters
    ----------
    image : numpy.ndarray
        Input image to erode.
    kernel : numpy.ndarray
        Structuring element (kernel) for erosion. Should contain 1s where
        the structuring element is active and 0s elsewhere.
    iterations : int, optional
        Number of times to apply the erosion. Default is 1.

    Returns
    -------
    numpy.ndarray
        Eroded image with the same shape and dtype as input.

    """
    result = image.copy()

    for _ in range(iterations):
        result = _single_erode(result, kernel)

    return result

@njit(parallel=True)
def _single_dilate(image, kernel):
    """Single iteration of dilation operation."""
    height, width = image.shape
    kh, kw = kernel.shape
    kh_half, kw_half = kh // 2, kw // 2

    # Create output array
    result = np.zeros_like(image)

    # Apply dilation - for each output pixel, find max in kernel neighborhood
    for i in prange(height):
        for j in range(width):
            max_val = image[i, j]  # Start with current pixel value

            for ki in range(kh):
                for kj in range(kw):
                    if kernel[ki, kj] > 0:  # Only consider active kernel elements
                        # Calculate the source image coordinates
                        img_i = i + ki - kh_half
                        img_j = j + kj - kw_half

                        # Check bounds
                        if 0 <= img_i < height and 0 <= img_j < width:
                            if image[img_i, img_j] > max_val:
                                max_val = image[img_i, img_j]

            result[i, j] = max_val

    return result

@njit(parallel=True)
def _single_erode(image, kernel):
    """Single iteration of erosion operation."""
    height, width = image.shape
    kh, kw = kernel.shape
    kh_half, kw_half = kh // 2, kw // 2

    # Create output array
    result = np.zeros_like(image)

    # Apply erosion - for each output pixel, find min in kernel neighborhood
    for i in prange(height):
        for j in range(width):
            min_val = image[i, j]  # Start with current pixel value

            for ki in range(kh):
                for kj in range(kw):
                    if kernel[ki, kj] > 0:  # Only consider active kernel elements
                        # Calculate the source image coordinates
                        img_i = i + ki - kh_half
                        img_j = j + kj - kw_half

                        # Check bounds - treat out of bounds as 0 for erosion
                        if 0 <= img_i < height and 0 <= img_j < width:
                            if image[img_i, img_j] < min_val:
                                min_val = image[img_i, img_j]
                        else:
                            # Outside bounds treated as 0, so erosion result should be 0
                            min_val = 0
                            break

            result[i, j] = min_val

    return result
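A small usage sketch of the new module, assuming a binary `uint8` mask and the vertical `(3, 1)` kernel that `plot.py` uses as its default:

import numpy as np
from batchflow.plotter.morphology import dilate, erode

mask = np.zeros((5, 5), dtype=np.uint8)
mask[2, 2] = 1                                # single active pixel

kernel = np.ones((3, 1), dtype=np.uint8)      # vertical structuring element
grown = dilate(mask, kernel, iterations=1)    # activates (1, 2), (2, 2), (3, 2)
shrunk = erode(grown, kernel, iterations=1)   # shrinks back to the original mask

The first call pays numba's JIT compilation cost; repeated calls reuse the compiled kernels.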
4 changes: 2 additions & 2 deletions batchflow/plotter/plot.py
@@ -103,7 +103,7 @@ def flatten(self, data):

     def dilate(self, data):
         """ Apply dilation to array. """
-        import cv2
+        from .morphology import dilate
         dilation_config = self.config.get('dilate', False)
 
         default_kernel = np.ones((3, 1), dtype=np.uint8)
@@ -116,7 +116,7 @@ def dilate(self, data):
             dilation_config = {'kernel': np.ones(dilation_config, dtype=np.uint8)}
         elif 'kernel' in dilation_config and isinstance(dilation_config['kernel'], tuple):
             dilation_config['kernel'] = np.ones(dilation_config['kernel'], dtype=np.uint8)
-        data = cv2.dilate(data.astype(np.float32), **dilation_config)
+        data = dilate(data.astype(np.float32), **dilation_config)
         return data
 
     def mask(self, data):
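Because the config dict built above typically carries just a `kernel` entry (tuples are expanded to all-ones kernels first), the numba function drops in with the same keyword call shape that `cv2.dilate` accepted. A quick sketch, assuming `float32` input as in the method:

import numpy as np
from batchflow.plotter.morphology import dilate

data = np.eye(8, dtype=np.float32)
dilation_config = {'kernel': np.ones((3, 1), dtype=np.uint8)}  # same form plot.py builds
result = dilate(data.astype(np.float32), **dilation_config)    # mirrors the replaced cv2.dilate call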
47 changes: 47 additions & 0 deletions batchflow/tests/model_save_load_test.py
@@ -1,5 +1,6 @@
""" Test for model saving and loading """

import os
import pickle

import pytest
@@ -186,3 +187,49 @@ def test_bare_model(self, save_path, model_class, pickle_module, outputs):
        loaded_predictions = model_load.predict(*args, **kwargs)

        assert (np.concatenate(saved_predictions) == np.concatenate(loaded_predictions)).all()

    @pytest.mark.parametrize("fmt", [None, 'onnx', 'openvino', 'safetensors'])
    @pytest.mark.parametrize("pickle_metadata", [False, True])
    def test_save_load_format(self, save_path, model_class, fmt, pickle_metadata):
        num_classes = 10
        dataset_size = 10
        image_shape = (2, 100, 100)

        save_kwargs = {
            None: {},
            'onnx': dict(batch_size=dataset_size),
            'openvino': {},
            'safetensors': {},
        }
        load_kwargs = {
            None: {},
            'onnx': {},
            'openvino': {'device': 'cpu'},
            'safetensors': {},
        }

        if fmt == 'openvino' and not pickle_metadata:
            save_path = os.path.splitext(save_path)[0] + '.xml'

        model_config = {
            'classes': num_classes,
            'inputs_shapes': image_shape,
            'output': 'sigmoid'
        }

        model_save = model_class(config=model_config)

        batch_shape = (dataset_size, *image_shape)
        images_array = np.random.random(batch_shape)

        inputs = images_array.astype('float32')

        saved_predictions = model_save.predict(inputs, outputs='sigmoid')
        model_save.save(path=save_path, pickle_metadata=pickle_metadata, fmt=fmt, **save_kwargs[fmt])

        load_config = {} if fmt != 'safetensors' else model_save.config
        model_load = model_class(config=load_config)
        model_load.load(path=save_path, fmt='pt' if pickle_metadata else fmt, **load_kwargs[fmt])
        loaded_predictions = model_load.predict(inputs, outputs='sigmoid')

        assert np.isclose(np.concatenate(saved_predictions), np.concatenate(loaded_predictions), atol=1e-3).all()
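Stripped of the parametrization, the round trip this test exercises for one format boils down to a few calls. A sketch; `TorchModel` is assumed as the concrete model class and the path is arbitrary:

import numpy as np
from batchflow.models.torch import TorchModel   # assumed concrete model class

config = {'classes': 10, 'inputs_shapes': (2, 100, 100), 'output': 'sigmoid'}
inputs = np.random.random((10, 2, 100, 100)).astype('float32')

model = TorchModel(config=config)
reference = model.predict(inputs, outputs='sigmoid')
model.save(path='model.onnx', fmt='onnx', batch_size=10)   # batch_size as in save_kwargs['onnx']

restored = TorchModel(config={})
restored.load(path='model.onnx', fmt='onnx')
check = restored.predict(inputs, outputs='sigmoid')
assert np.isclose(np.concatenate(reference), np.concatenate(check), atol=1e-3).all()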
1 change: 1 addition & 0 deletions batchflow/tests/research_test.py
@@ -570,6 +570,7 @@ def f(a):
assert research.results.df.iloc[0].a == f(2)
assert research.results.df.iloc[0].b == f(3)

@pytest.mark.slow
@pytest.mark.parametrize('dump_results', [False, True])
@pytest.mark.parametrize('redirect_stdout', [True, 0, 1, 2, 3])
@pytest.mark.parametrize('redirect_stderr', [True, 0, 1, 2, 3])
18 changes: 16 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "batchflow"
-version = "0.8.12"
+version = "0.9.0"
 description = "ML pipelines, model configuration and batch management"
 authors = [{ name = "Roman Kh", email = "[email protected]" }]
 license = {text = "Apache License 2.0"}
@@ -25,7 +25,8 @@ dependencies = [
     "numba>=0.56",
     "llvmlite",
     "scipy>=1.9",
-    "tqdm>=4.19"
+    "tqdm>=4.19",
+    "pytest>=8.3.4",
 ]

[project.optional-dependencies]
@@ -74,6 +75,19 @@ telegram = [
     "pillow>=9.4,<11.0",
 ]
 
+safetensors = [
+    "safetensors>=0.5.3",
+]
+
+onnx = [
+    "onnx>=1.14.0",
+    "onnx2torch>=1.5.0",
+]
+
+openvino = [
+    "openvino>=2025.0.0",
+]
+
 other = [
     "urllib3>=1.25"
 ]
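With these extras declared, each export backend installs on demand via the usual extras syntax, e.g. `pip install batchflow[onnx]`, `pip install batchflow[openvino]`, or `pip install batchflow[safetensors]` (they can be combined, e.g. `batchflow[onnx,safetensors]`).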