What are some strategies for efficient data analysis using xarray / Dask multiprocessing?
Hi! For efficient data analysis, a helpful first step is to check how your data is stored on disk. In chunked formats such as netCDF4/HDF5 or Zarr, the data is divided into fixed-size chunks that are stored (and often compressed) separately. When reading or processing the data, the smallest readable unit is one full chunk: even if you need only part of a chunk, the entire chunk must be read and decompressed into memory. To optimize performance, your in-memory (Dask) chunking should align well with the dataset’s on-disk chunking. This minimizes redundant disk reads and decompression, and speeds up processing.
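The whole-chunk rule can be made concrete with a small pure-Python model. The helpers `chunks_touched` and `elements_read` below are hypothetical (not part of xarray or Dask) and assume fixed-size chunks along a single axis, starting at index 0. They show that a 10-element request still costs a full 366-element chunk, or two full chunks if the request straddles a chunk boundary.

```python
def chunks_touched(start, stop, chunk_size):
    """Return the indices of the fixed-size on-disk chunks overlapped by [start, stop)."""
    if stop <= start:
        return []
    first = start // chunk_size
    last = (stop - 1) // chunk_size
    return list(range(first, last + 1))


def elements_read(start, stop, chunk_size, n_total):
    """Number of elements loaded from disk when only whole chunks can be read."""
    total = 0
    for c in chunks_touched(start, stop, chunk_size):
        chunk_start = c * chunk_size
        # The last chunk of the axis may be shorter than chunk_size.
        chunk_stop = min(chunk_start + chunk_size, n_total)
        total += chunk_stop - chunk_start
    return total


# 10 days requested from a 1000-day axis stored in 366-day chunks
print(chunks_touched(400, 410, 366))       # [1]
print(elements_read(400, 410, 366, 1000))  # 366
# 10 days that straddle the boundary between chunk 0 and chunk 1
print(elements_read(360, 370, 366, 1000))  # 732
```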
Before loading data, you can inspect its chunking structure using:
ds = xr.open_dataset(url, decode_times=False, chunks={})  # open lazily: only metadata is read at this point
print(ds["tas"].attrs.get("_ChunkSizes"))  # on-disk chunk sizes reported by the THREDDS server
This will help you determine how the data is stored so that you can match it appropriately in memory.
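`_ChunkSizes` is a plain list ordered like the variable’s dimensions, so it can help to pair it with `ds["tas"].dims` before choosing in-memory chunks. A minimal sketch, assuming the attribute holds one integer per dimension; the helper name `disk_chunks_by_dim` is hypothetical, and the sample values mirror the [366, 50, 50] example below rather than a live query.

```python
def disk_chunks_by_dim(dims, chunk_sizes):
    """Pair each dimension name with its on-disk chunk length (hypothetical helper)."""
    try:
        sizes = [int(s) for s in chunk_sizes]
    except TypeError:
        # A single scalar value (e.g. for a 1-D variable) is not iterable.
        sizes = [int(chunk_sizes)]
    if len(sizes) != len(dims):
        raise ValueError(f"expected {len(dims)} chunk sizes, got {len(sizes)}")
    return dict(zip(dims, sizes))


# With a real dataset this would be:
# disk_chunks_by_dim(ds["tas"].dims, ds["tas"].attrs["_ChunkSizes"])
print(disk_chunks_by_dim(("time", "lat", "lon"), [366, 50, 50]))
# {'time': 366, 'lat': 50, 'lon': 50}
```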
Suppose the on-disk chunking of daily data is [366, 50, 50] (i.e., each chunk holds 366 days over a 50×50 block of grid points), but you load the dataset into memory with mismatched chunking and perform a resampling operation:
ds = xr.open_dataset(url, chunks={"time": 10, "lat": -1, "lon": -1}) # 10-day time chunks, full spatial extent
tas_resampled = ds['tas'].sel(time=slice('1981-01-01', '1981-12-31'), lat=slice(35, 65), lon=slice(-100, -60)).resample(time='YS').mean().compute()
Here, resampling daily data to yearly means needs every day of the year, but each in-memory chunk holds only 10 days. Because the on-disk chunk is 366 days long, each of the ceil(366 / 10) = 37 in-memory chunks has to read that same on-disk chunk in full, so it is read 37 times instead of once: 36 redundant reads.
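The read count can be checked with a short worst-case model. The function `disk_chunk_reads` is hypothetical and assumes that every in-memory chunk reads each on-disk chunk it overlaps in full, with nothing cached between in-memory chunks (in practice a chunk cache may absorb some of this, but the mismatch still costs extra reads).

```python
def disk_chunk_reads(n_steps, disk_chunk, mem_chunk):
    """Count on-disk chunk reads when each in-memory chunk is loaded independently."""
    reads = 0
    for mem_start in range(0, n_steps, mem_chunk):
        mem_stop = min(mem_start + mem_chunk, n_steps)
        first = mem_start // disk_chunk
        last = (mem_stop - 1) // disk_chunk
        # Each overlapped on-disk chunk is read in full for this in-memory chunk.
        reads += last - first + 1
    return reads


print(disk_chunk_reads(366, 366, 10))   # 37 reads of the same on-disk chunk
print(disk_chunk_reads(366, 366, 366))  # 1 read when the chunking is aligned
```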
A better approach is to align the in-memory chunking with the on-disk chunking:
ds = xr.open_dataset(url, chunks={"time": 366, "lat": 50 * 5, "lon": 50 * 5}) # match time chunking, optimize spatial chunking
tas_resampled = ds['tas'].sel(time=slice('1981-01-01', '1981-12-31'), lat=slice(35, 65), lon=slice(-100, -60)).resample(time='YS').mean().compute()
With this setup, each in-memory chunk along time matches exactly one on-disk chunk (366 days), so each on-disk chunk is read only once and the annual mean can be computed without reloading chunks. For spatial chunking, setting lat=-1, lon=-1 isn’t ideal because it loads the entire lat-lon grid as a single chunk, which limits parallelism and can produce very large chunks. A better approach is to set spatial chunking to multiples of the on-disk chunk size, such as lat=50*5, lon=50*5, so that each in-memory chunk covers a whole block of 5 × 5 on-disk chunks. The domain is still split into several chunks that Dask can process in parallel, while memory usage per chunk stays under control.
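Building the `chunks` dict from the on-disk sizes and estimating the memory footprint of one chunk can be sketched as follows. Both helpers are hypothetical, and `itemsize=4` assumes float32 data; check `ds["tas"].dtype` for the real item size.

```python
def aligned_chunks(disk_chunks, multiples):
    """Build a `chunks=` dict whose sizes are whole multiples of the on-disk chunks."""
    chunks = {}
    for dim, size in disk_chunks.items():
        factor = multiples.get(dim, 1)  # dimensions not listed keep the on-disk size
        if factor < 1:
            raise ValueError(f"multiple for {dim!r} must be >= 1")
        chunks[dim] = size * factor
    return chunks


def chunk_nbytes(chunks, itemsize=4):
    """Approximate in-memory size of one full chunk (itemsize=4 assumes float32)."""
    n = 1
    for size in chunks.values():
        n *= size
    return n * itemsize


disk = {"time": 366, "lat": 50, "lon": 50}
chunks = aligned_chunks(disk, {"lat": 5, "lon": 5})
print(chunks)                      # {'time': 366, 'lat': 250, 'lon': 250}
print(chunk_nbytes(chunks) / 1e6)  # 91.5 (MB per chunk, assuming float32)
# ds = xr.open_dataset(url, chunks=chunks)
```

Edge chunks can be smaller than the full size, so this estimate is an upper bound per chunk; with two threads per worker, roughly two such chunks (plus intermediates) may be in memory on a worker at once.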
Another key step for efficient data analysis is using the Dask Distributed Client, which manages computations across multiple workers and CPU cores. It schedules tasks dynamically and keeps each worker within its memory limit (for example by spilling data to disk).
To take full advantage of this, output writes can be defined as delayed tasks by passing compute=False to to_zarr (or to_netcdf). This returns a dask.delayed object and builds a task graph without loading any data into memory. Once all steps are defined, passing the delayed objects to dask.compute() triggers execution. At this point, Dask analyzes the combined task graph, reuses intermediate results shared between tasks, distributes the workload across the available workers and executes the tasks in an efficient order.
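The deferred-execution idea can be illustrated with a toy model of Dask’s classic graph specification, where a graph is a dict mapping keys to task tuples `(callable, *args)` and arguments that name other keys are replaced by their results. The executor below is a simplified sketch, not Dask’s scheduler, and `load` is a hypothetical stand-in for reading data. It shows that defining the graph runs nothing, and that computing two outputs together evaluates their shared input only once.

```python
calls = {"load": 0}


def load():
    # Hypothetical stand-in for reading data from disk; counts how often it runs.
    calls["load"] += 1
    return list(range(10))


def mean(values):
    return sum(values) / len(values)


def maximum(values):
    return max(values)


# Defining the graph only records what to do; no function is called yet.
graph = {
    "data": (load,),
    "mean": (mean, "data"),
    "max": (maximum, "data"),
}


def compute_keys(graph, keys):
    """Evaluate the requested keys, computing each graph node at most once per call."""
    cache = {}

    def get(key):
        if key not in cache:
            func, *args = graph[key]
            # Arguments that name other graph keys are resolved first.
            resolved = [get(a) if isinstance(a, str) and a in graph else a for a in args]
            cache[key] = func(*resolved)
        return cache[key]

    return tuple(get(k) for k in keys)


print(calls["load"])                         # 0: nothing has run yet
print(compute_keys(graph, ["mean", "max"]))  # (4.5, 9)
print(calls["load"])                         # 1: the shared input was loaded once
```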
Here’s an example of computing yearly means and saving the result to a Zarr file using Dask:
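from pathlib import Path
import xarray as xr
from dask import compute
from dask.distributed import Client
from IPython.display import display  # display() is predefined in Jupyter; imported here for clarity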
# function to compute the yearly mean for a given variable
def compute_yearly_mean(ds, var_name):
    var = ds[var_name]
    yearly_mean = var.resample(time="YS").mean()
    return yearly_mean
# set up Dask client within a context manager
with Client(n_workers=5, threads_per_worker=2, memory_limit="4GB") as client:
    # display the Dask Dashboard
    display(client)
    # open the dataset with on-disk chunking structure
    url = "https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/datasets/reanalyses/day_ERA5-Land_NAM.ncml"
    ds = xr.open_dataset(url, chunks={"time": 366, "lat": 50, "lon": 50})
    # focus on Eastern North America and a single year for this example
    ds = ds.sel(
        time=slice("1981-01-01", "1981-12-31"), lat=slice(35, 65), lon=slice(-100, -60)
    )
    # define the variable to compute yearly means for
    variables = ["tas"]
    # create a list to hold delayed tasks
    tasks = []
    for var_name in variables:
        output_path = Path(f"var_means/{var_name}_1981_yearly_mean.zarr")
        if not output_path.exists():
            yearly_mean = compute_yearly_mean(ds, var_name)
            # save to Zarr with compute=False to get a delayed task object
            delayed_task = yearly_mean.chunk(
                {"time": -1, "lat": 50, "lon": 50}
            ).to_zarr(output_path, mode="w", compute=False)
            tasks.append(delayed_task)
    # trigger the execution of all delayed tasks
    compute(*tasks)
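The example hard-codes `n_workers=5, threads_per_worker=2, memory_limit="4GB"`. One possible way to derive such values from the machine is sketched below; the helper `plan_workers` and its 80 % memory share are assumptions for illustration, not a Dask recommendation. Total memory is passed in explicitly (on a real machine it could come from `psutil.virtual_memory().total`), so the sketch itself needs only the standard library.

```python
import os


def plan_workers(n_cores, total_mem_bytes, threads_per_worker=2, mem_percent=80):
    """Hypothetical heuristic: split cores and a share of RAM evenly across workers."""
    n_workers = max(1, n_cores // threads_per_worker)
    # Leave (100 - mem_percent)% of RAM for the OS and the client process.
    per_worker = total_mem_bytes * mem_percent // 100 // n_workers
    return {
        "n_workers": n_workers,
        "threads_per_worker": threads_per_worker,
        "memory_limit": per_worker,  # an integer memory_limit is interpreted as bytes
    }


print(plan_workers(10, 25 * 10**9))
# {'n_workers': 5, 'threads_per_worker': 2, 'memory_limit': 4000000000}

# 16 GB is a placeholder here; substitute the machine's real total memory.
local_settings = plan_workers(os.cpu_count() or 1, 16 * 10**9)
# with Client(**local_settings) as client: ...
```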
For more examples on using Dask and xarray for parallel processing, check out this Jupyter Notebook on PAVICS: FAQ_dask_parallel.ipynb