Skip to content
Merged
39 changes: 28 additions & 11 deletions contextily/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import rasterio as rio
from PIL import Image
from joblib import Memory as _Memory
from joblib import Parallel, delayed
from rasterio.transform import from_origin
from rasterio.io import MemoryFile
from rasterio.vrt import WarpedVRT
Expand Down Expand Up @@ -74,6 +75,7 @@ def bounds2raster(
ll=False,
wait=0,
max_retries=2,
n_connections=1,
):
"""
Take bounding box and zoom, and write tiles into a raster file in
Expand Down Expand Up @@ -113,6 +115,9 @@ def bounds2raster(
[Optional. Default: 2]
total number of rejected requests allowed before contextily
will stop trying to fetch more tiles from a rate-limited API.
n_connections: int
[Optional. Default: 1]
number of connections for downloading tiles in parallel.

Returns
-------
Expand All @@ -126,7 +131,9 @@ def bounds2raster(
w, s = _sm2ll(w, s)
e, n = _sm2ll(e, n)
# Download
Z, ext = bounds2img(w, s, e, n, zoom=zoom, source=source, ll=True)
Z, ext = bounds2img(w, s, e, n, zoom=zoom, source=source, ll=True,
n_connections=n_connections)

# Write
# ---
h, w, b = Z.shape
Expand Down Expand Up @@ -155,7 +162,7 @@ def bounds2raster(


def bounds2img(
w, s, e, n, zoom="auto", source=None, ll=False, wait=0, max_retries=2
w, s, e, n, zoom="auto", source=None, ll=False, wait=0, max_retries=2, n_connections=1
):
"""
Take bounding box and zoom and return an image with all the tiles
Expand Down Expand Up @@ -193,6 +200,9 @@ def bounds2img(
[Optional. Default: 2]
total number of rejected requests allowed before contextily
will stop trying to fetch more tiles from a rate-limited API.
n_connections: int
[Optional. Default: 1]
number of connections for downloading tiles in parallel.

Returns
-------
Expand All @@ -213,15 +223,22 @@ def bounds2img(
if auto_zoom:
zoom = _calculate_zoom(w, s, e, n)
zoom = _validate_zoom(zoom, provider, auto=auto_zoom)
# download and merge tiles
tiles = []
arrays = []
for t in mt.tiles(w, s, e, n, [zoom]):
x, y, z = t.x, t.y, t.z
tile_url = provider.build_url(x=x, y=y, z=z)
image = _fetch_tile(tile_url, wait, max_retries)
tiles.append(t)
arrays.append(image)
# create list of tiles to download
tiles = list(mt.tiles(w, s, e, n, [zoom]))
tile_urls = [provider.build_url(x=tile.x, y=tile.y, z=tile.z) for tile in tiles]
# download tiles
max_connections = 32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for this specific upper limit? I am happy leaving that a responsibility of a user.

if n_connections < 1 or n_connections > max_connections:
raise ValueError(
f"n_connections must be between 1 and {max_connections}"
)
# Use threads for a single connection to avoid the overhead of spawning a process. For multiple connections, use
# processes, as threads lead to memory issues when used in combination with the joblib memory caching (used for
# the _fetch_tile() function).
preferred_backend = "threads" if n_connections == 1 else "processes"
arrays = Parallel(n_jobs=n_connections, prefer=preferred_backend)(
delayed(_fetch_tile)(tile_url, wait, max_retries) for tile_url in tile_urls)
# merge downloaded tiles
merged, extent = _merge_tiles(tiles, arrays)
# lon/lat extent --> Spheric Mercator
west, south, east, north = extent
Expand Down
31 changes: 18 additions & 13 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,31 @@ def test_bounds2raster():
assert_array_almost_equal(list(rtr.bounds), rtr_bounds)


@pytest.mark.parametrize("n_connections", [0, 1, 16, 33])
@pytest.mark.network
def test_bounds2img():
def test_bounds2img(n_connections):
w, s, e, n = (
-106.6495132446289,
25.845197677612305,
-93.50721740722656,
36.49387741088867,
)
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True)
solu = (
-12523442.714243276,
-10018754.171394622,
2504688.5428486555,
5009377.085697309,
)
for i, j in zip(ext, solu):
assert round(i - j, TOL) == 0
assert img[100, 100, :].tolist() == [230, 229, 188, 255]
assert img[100, 200, :].tolist() == [156, 180, 131, 255]
assert img[200, 100, :].tolist() == [230, 225, 189, 255]
if n_connections in [1, 16]: # accepted number of connections
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True, n_connections=n_connections)
solu = (
-12523442.714243276,
-10018754.171394622,
2504688.5428486555,
5009377.085697309,
)
for i, j in zip(ext, solu):
assert round(i - j, TOL) == 0
assert img[100, 100, :].tolist() == [230, 229, 188, 255]
assert img[100, 200, :].tolist() == [156, 180, 131, 255]
assert img[200, 100, :].tolist() == [230, 225, 189, 255]
else: # too few/many connections should raise an error
with pytest.raises(ValueError):
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True, n_connections=n_connections)


@pytest.mark.network
Expand Down