Initial commit

This commit is contained in:
2026-02-01 09:31:38 +01:00
commit e02db93960
4396 changed files with 1511612 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
"""
Small utilities for testing.
"""
import gc
import os
import sys
import sysconfig
from joblib._multiprocessing_helpers import mp
from joblib.testing import SkipTest, skipif
try:
import lz4
except ImportError:
lz4 = None
# TODO straight removal since in joblib.test.common?
# True when running under the PyPy interpreter (detected via its
# interpreter-specific sys attribute).
IS_PYPY = hasattr(sys, "pypy_version_info")
# True on free-threaded CPython builds when the GIL is actually disabled at
# runtime: sysconfig reports build-time support, sys reports the live state.
IS_GIL_DISABLED = (
    sysconfig.get_config_var("Py_GIL_DISABLED") and not sys._is_gil_enabled()
)
# A decorator to run tests only when numpy is available
try:
    import numpy as np

    _numpy_missing = False
except ImportError:
    np = None
    _numpy_missing = True


def with_numpy(func):
    """A decorator to skip tests requiring numpy."""
    if not _numpy_missing:
        # numpy is importable: the test runs unchanged.
        return func

    def my_func():
        raise SkipTest("Test requires numpy")

    return my_func


# TODO: Turn this back on after refactoring yield based tests in test_hashing
# with_numpy = skipif(not np, reason='Test requires numpy.')
# we use memory_profiler library for memory consumption checks
try:
    from memory_profiler import memory_usage

    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""
        return func

    def memory_used(func, *args, **kwargs):
        """Compute memory usage when executing func."""
        # Collect garbage first so previously freed objects do not pollute
        # the measurement.
        gc.collect()
        mem_use = memory_usage((func, args, kwargs), interval=0.001)
        # Peak minus baseline over the sampled trace approximates the
        # memory consumed by the call itself.
        return max(mem_use) - min(mem_use)

except ImportError:
    # memory_profiler not installed: decorated tests skip at call time.
    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""

        def dummy_func():
            raise SkipTest("Test requires memory_profiler.")

        return dummy_func

    memory_usage = memory_used = None
# Skip markers describing capabilities of the host environment and the
# availability of optional dependencies.
with_multiprocessing = skipif(mp is None, reason="Needs multiprocessing to run.")
# /dev/shm is the Linux tmpfs mount used by the shared-memory memmapping tests.
with_dev_shm = skipif(
    not os.path.exists("/dev/shm"),
    reason="This test requires a large /dev/shm shared memory fs.",
)
with_lz4 = skipif(lz4 is None, reason="Needs lz4 compression to run")
without_lz4 = skipif(lz4 is not None, reason="Needs lz4 not being installed to run")

View File

@@ -0,0 +1,106 @@
"""
This script is used to generate test data for joblib/test/test_numpy_pickle.py
"""
import re
import sys
# pytest needs to be able to import this module even when numpy is
# not installed
try:
import numpy as np
except ImportError:
np = None
import joblib
def get_joblib_version(joblib_version=None):
    """Normalize joblib version by removing suffix.

    Parameters
    ----------
    joblib_version : str, optional
        Version string to normalize. Defaults to the installed
        ``joblib.__version__``, resolved lazily at call time rather than
        at function-definition time (the original default froze the value
        at import and made the module un-importable without joblib).

    >>> get_joblib_version('0.8.4')
    '0.8.4'
    >>> get_joblib_version('0.8.4b1')
    '0.8.4'
    >>> get_joblib_version('0.9.dev0')
    '0.9'
    """
    if joblib_version is None:
        joblib_version = joblib.__version__
    # Keep only the leading digits of each dot-separated component and drop
    # components that do not start with a digit (e.g. 'dev0').
    matches = [re.match(r"(\d+).*", each) for each in joblib_version.split(".")]
    return ".".join([m.group(1) for m in matches if m is not None])
def write_test_pickle(to_pickle, args):
    """Dump ``to_pickle`` into a file whose name encodes the pickle options.

    The generated filename follows the convention expected by the
    test_numpy_pickle tests: joblib version, optional body tags, Python and
    numpy versions, and an extension reflecting the compression method.
    """
    kwargs = {}
    compress = args.compress
    method = args.method
    joblib_version = get_joblib_version()
    # e.g. sys.version_info (3, 11, ...) -> "311"
    py_version = "{0[0]}{0[1]}".format(sys.version_info)
    # e.g. "1.26.4" -> "126" (major+minor only)
    numpy_version = "".join(np.__version__.split(".")[:2])
    # The game here is to generate the right filename according to the options.
    body = "_compressed" if (compress and method == "zlib") else ""
    if compress:
        if method == "zlib":
            # Legacy zlib pickles use compress=True and a plain .gz extension.
            kwargs["compress"] = True
            extension = ".gz"
        else:
            # Other methods use the (method, level) tuple form, level 3.
            kwargs["compress"] = (method, 3)
            extension = ".pkl.{}".format(method)
        if args.cache_size:
            # cache_size=0 forces the creation of companion .npy files.
            kwargs["cache_size"] = 0
            body += "_cache_size"
    else:
        extension = ".pkl"
    pickle_filename = "joblib_{}{}_pickle_py{}_np{}{}".format(
        joblib_version, body, py_version, numpy_version, extension
    )
    try:
        joblib.dump(to_pickle, pickle_filename, **kwargs)
    except Exception as e:
        # With old python version (<= 3.3), we can arrive there when
        # dumping compressed pickle with LzmaFile.
        print(
            "Error: cannot generate file '{}' with arguments '{}'. "
            "Error was: {}".format(pickle_filename, kwargs, e)
        )
    else:
        print("File '{}' generated successfully.".format(pickle_filename))
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Joblib pickle data generator.")
    parser.add_argument(
        "--cache_size",
        action="store_true",
        help="Force creation of companion numpy files for pickled arrays.",
    )
    parser.add_argument(
        "--compress", action="store_true", help="Generate compress pickles."
    )
    parser.add_argument(
        "--method",
        type=str,
        default="zlib",
        choices=["zlib", "gzip", "bz2", "xz", "lzma", "lz4"],
        help="Set compression method.",
    )
    # We need to be specific about dtypes in particular endianness
    # because the pickles can be generated on one architecture and
    # the tests run on another one. See
    # https://github.com/joblib/joblib/issues/279.
    to_pickle = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        # all possible bytes as a byte string
        np.arange(256, dtype=np.uint8).tobytes(),
        # NOTE(review): np.matrix is deprecated in numpy; presumably kept on
        # purpose to preserve historical pickle fixtures -- confirm before
        # changing.
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        # unicode string with non-ascii chars
        "C'est l'\xe9t\xe9 !",
    ]
    write_test_pickle(to_pickle, parser.parse_args())

View File

@@ -0,0 +1,35 @@
import mmap
from joblib import Parallel, delayed
from joblib.backports import concurrency_safe_rename, make_memmap
from joblib.test.common import with_numpy
from joblib.testing import parametrize
@with_numpy
def test_memmap(tmpdir):
    """make_memmap must forward a non-granularity-aligned offset unchanged."""
    target = tmpdir.join("test.mmap").strpath
    granularity = mmap.ALLOCATIONGRANULARITY
    requested_offset = granularity + 1
    mm = make_memmap(target, shape=5 * granularity, mode="w+", offset=requested_offset)
    assert mm.offset == requested_offset
@parametrize("dst_content", [None, "dst content"])
@parametrize("backend", [None, "threading"])
def test_concurrency_safe_rename(tmpdir, dst_content, backend):
    """Concurrent renames of several sources onto one destination must leave
    exactly one winning file and remove every source."""
    sources = [tmpdir.join("src_%d" % i) for i in range(4)]
    for source in sources:
        source.write("src content")
    destination = tmpdir.join("dst")
    if dst_content is not None:
        destination.write(dst_content)
    Parallel(n_jobs=4, backend=backend)(
        delayed(concurrency_safe_rename)(source.strpath, destination.strpath)
        for source in sources
    )
    assert destination.exists()
    assert destination.read() == "src content"
    assert not any(source.exists() for source in sources)

View File

@@ -0,0 +1,28 @@
"""
Test that our implementation of wrap_non_picklable_objects mimics
properly the loky implementation.
"""
from .._cloudpickle_wrapper import (
_my_wrap_non_picklable_objects,
wrap_non_picklable_objects,
)
def a_function(x):
    """Module-level identity function, used as a picklable test target."""
    return x
class AClass(object):
    """Callable class whose instances act as the identity function."""

    def __call__(self, x):
        return x
def test_wrap_non_picklable_objects():
    # Mostly a smoke test: both our wrapper and the upstream loky one must
    # produce callables that behave identically on functions and instances.
    for target in (a_function, AClass()):
        upstream = wrap_non_picklable_objects(target)
        ours = _my_wrap_non_picklable_objects(target)
        assert upstream(1) == ours(1)

View File

@@ -0,0 +1,157 @@
import os
from joblib._parallel_backends import (
LokyBackend,
MultiprocessingBackend,
ThreadingBackend,
)
from joblib.parallel import (
BACKENDS,
DEFAULT_BACKEND,
EXTERNAL_BACKENDS,
Parallel,
delayed,
parallel_backend,
parallel_config,
)
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.test.test_parallel import check_memmap
from joblib.testing import parametrize, raises
@parametrize("context", [parallel_config, parallel_backend])
def test_global_parallel_backend(context):
    # Entering the context switches the process-global default backend;
    # unregistering must restore the previous default.
    default = Parallel()._backend
    pb = context("threading")
    try:
        assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        pb.unregister()
    assert type(Parallel()._backend) is type(default)


@parametrize("context", [parallel_config, parallel_backend])
def test_external_backends(context):
    # Backends listed in EXTERNAL_BACKENDS are registered lazily via a
    # callback the first time their name is requested.
    def register_foo():
        BACKENDS["foo"] = ThreadingBackend

    EXTERNAL_BACKENDS["foo"] = register_foo
    try:
        with context("foo"):
            assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        # Clean up the global registry so other tests are unaffected.
        del EXTERNAL_BACKENDS["foo"]
@with_numpy
@with_multiprocessing
def test_parallel_config_no_backend(tmpdir):
    # Check that parallel_config allows to change the config
    # even if no backend is set.
    with parallel_config(n_jobs=2, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(prefer="processes") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2
            # Checks that memmapping is enabled
            p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
            assert len(os.listdir(tmpdir)) > 0


@with_numpy
@with_multiprocessing
def test_parallel_config_params_explicit_set(tmpdir):
    # Arguments given to the Parallel constructor take precedence over the
    # surrounding parallel_config settings (n_jobs and max_nbytes here).
    with parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(n_jobs=2, prefer="processes", max_nbytes="1M") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2
            # Checks that memmapping is disabled
            with raises(TypeError, match="Expected np.memmap instance"):
                p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)


@parametrize("param", ["prefer", "require"])
def test_parallel_config_bad_params(param):
    # Check that an error is raised when setting a wrong backend
    # hint or constraint
    with raises(ValueError, match=f"{param}=wrong is not a valid"):
        with parallel_config(**{param: "wrong"}):
            Parallel()
def test_parallel_config_constructor_params():
    # Check that an error is raised when backend is None
    # but backend constructor params are given
    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(inner_max_num_threads=1):
            pass
    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(backend_param=1):
            pass
    # Backend constructor params also require the backend to be given by
    # name, not as a class/instance.
    with raises(ValueError, match="only supported when backend is a string"):
        with parallel_config(backend=BACKENDS[DEFAULT_BACKEND], backend_param=1):
            pass


def test_parallel_config_nested():
    # Check that nested configuration retrieves the info from the
    # parent config and do not reset them.
    with parallel_config(n_jobs=2):
        p = Parallel()
        assert isinstance(p._backend, BACKENDS[DEFAULT_BACKEND])
        assert p.n_jobs == 2
    with parallel_config(backend="threading"):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert isinstance(p._backend, ThreadingBackend)
            assert p.n_jobs == 2
    with parallel_config(verbose=100):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert p.verbose == 100
            assert p.n_jobs == 2
@with_numpy
@with_multiprocessing
@parametrize(
    "backend",
    ["multiprocessing", "threading", MultiprocessingBackend(), ThreadingBackend()],
)
@parametrize("context", [parallel_config, parallel_backend])
def test_threadpool_limitation_in_child_context_error(context, backend):
    # Backends that cannot limit the thread count inside their workers must
    # reject inner_max_num_threads explicitly.
    with raises(AssertionError, match=r"does not acc.*inner_max_num_threads"):
        context(backend, inner_max_num_threads=1)


@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_n_jobs_none(context):
    # Check that n_jobs=None is interpreted as "unset" in Parallel
    # non regression test for #1473
    with context(backend="threading", n_jobs=2):
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == 2
    with context(backend="threading"):
        default_n_jobs = Parallel().n_jobs
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == default_n_jobs


@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_config_n_jobs_none(context):
    # Check that n_jobs=None is interpreted as "explicitly set" in
    # parallel_(config/backend)
    # non regression test for #1473
    with context(backend="threading", n_jobs=2):
        with context(backend="threading", n_jobs=None):
            # n_jobs=None resets n_jobs to backend's default
            with Parallel() as p:
                assert p.n_jobs == 1

View File

@@ -0,0 +1,607 @@
from __future__ import absolute_import, division, print_function
import os
import warnings
from random import random
from time import sleep
from uuid import uuid4
import pytest
from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
_recursive_backend_info,
_test_deadlock_with_generator,
_test_parallel_unordered_generator_returns_fastest_first, # noqa: E501
)
distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")
# These imports need to be after the pytest.importorskip hence the noqa: E402
from distributed import Client, LocalCluster, get_client # noqa: E402
from distributed.metrics import time # noqa: E402
# Note: pytest requires to manually import all fixtures used in the test
# and their dependencies.
from distributed.utils_test import cleanup, cluster, inc # noqa: E402, F401
@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    # when starting a dask nanny, the environment variable might change.
    # this fixture makes sure the environment is reset after the test.
    from joblib._parallel_backends import ParallelBackendBase

    # Snapshot the thread-limiting variables before the test runs ...
    old_value = {k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS}
    yield
    # ... and restore them afterwards, removing those that were unset.
    for k, v in old_value.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v
def noop(*args, **kwargs):
    """Accept any arguments, do nothing, return None."""
def slow_raise_value_error(condition, duration=0.05):
    """Sleep for ``duration`` seconds, then raise ValueError iff ``condition``."""
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")
def count_events(event_name, client):
    """Return {worker_address: number of logged events named ``event_name``}."""
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    return {
        worker: sum(1 for event in list(events) if event[1] == event_name)
        for worker, events in worker_events.items()
    }
def test_simple(loop):
    # Basic smoke test: with the dask backend, Parallel maps work over the
    # cluster and surfaces worker-side exceptions as the original type.
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]
                with pytest.raises(ValueError):
                    Parallel()(
                        delayed(slow_raise_value_error)(i == 3) for i in range(10)
                    )
                # The backend must remain usable after a failed call.
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    # The dask backend inherits the adaptive batching of AutoBatchingMixin.
    assert (
        DaskDistributedBackend.compute_batch_size
        is AutoBatchingMixin.compute_batch_size
    )
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1
                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(delayed(lambda: None)() for _ in range(int(1e4)))
                    assert backend._effective_batch_size > 10


@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    # Delegates to the shared scenario from test_parallel, run on a dask client.
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
    # Delegates to the shared deadlock scenario from test_parallel.
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_deadlock_with_generator(None, return_as, n_jobs)


@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
    # Nested Parallel calls must keep using the dask backend at every level.
    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        data = np.ones(int(1e7), dtype=np.uint8)
        for i in range(2):
            with context("dask"):
                backend_types_and_levels = _recursive_backend_info(data=data)
                assert len(backend_types_and_levels) == 4
                assert all(
                    name == "DaskDistributedBackend"
                    for name, _ in backend_types_and_levels
                )
        # No argument
        with context("dask"):
            backend_types_and_levels = _recursive_backend_info()
            assert len(backend_types_and_levels) == 4
            assert all(
                name == "DaskDistributedBackend"
                for name, _ in backend_types_and_levels
            )
def random2():
    """Module-level (hence picklable) indirection around ``random.random``."""
    return random()
def test_dont_assume_function_purity(loop):
    # Two submissions of the same zero-argument function must be executed
    # twice: dask must not deduplicate/cache the tasks.
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y


@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    # Dask task names are derived from the wrapped function names so that
    # they remain readable in the scheduler dashboard and logs.
    from joblib._dask import Batch

    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"
    else:
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"
    assert repr(Batch(tasks)) == batch_repr
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:
            with parallel_config(backend="dask"):
                _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

            def f(dask_scheduler):
                return list(dask_scheduler.transition_log)

            # batch_size=2 splits the 4 tasks into batches of 2 calls.
            batch_repr = batch_repr.replace("4", "2")
            log = client.run_on_scheduler(f)
            assert all("batch_of_inc" in tup[0] for tup in log)
def test_no_undesired_distributed_cache_hit():
    # Dask has a pickle cache for callables that are called many times. Because
    # the dask backends used to wrap both the functions and the arguments
    # under instances of the Batch callable class this caching mechanism could
    # lead to bugs as described in: https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as non-regression problem.
    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick-in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip("numpy")  # intentionally shadows module-level np
    X = np.arange(int(1e6))

    def isolated_operation(list_, data=None):
        # Runs on the worker: mutates its own (deserialized) copy of list_.
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask"):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)
        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]
        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])
        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )
        # This time, auto-scattering should have kicked it.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)
class CountSerialized(object):
    """Wrap a value and count how many times the wrapper gets pickled."""

    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        # Accept either a raw value or another CountSerialized instance.
        return self.x + getattr(other, "x", other)

    __radd__ = __add__

    def __reduce__(self):
        # Every serialization round bumps the counter on the source object;
        # the reconstructed copy starts from a fresh count of zero.
        self.count += 1
        return (CountSerialized, (self.x,))
def add5(a, b, c, d=0, e=0):
    """Sum up to five values; the last two default to zero."""
    return a + b + c + d + e
def test_manual_scatter(loop):
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and equivalent native `client.submit` call.
    # Number of serializations can vary from dask to another, so this test only
    # checks that `joblib.Parallel` does not add more serialization steps than
    # a native `client.submit` call, but does not check for an exact number of
    # serialization steps.
    w, x, y, z = (CountSerialized(i) for i in range(4))
    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    # delayed() produces (func, args, kwargs) triples: compute locally first.
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]
    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected
            # Check that an error is raised for bad arguments, as scatter must
            # take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass
    # Scattered variables only serialized during scatter. Checking with an
    # extra variable as this count can vary from one dask version
    # to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count
    # Reset the cluster and the serialization count
    for var in (w, x, y, z):
        var.count = 0
    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            # Re-run the same workload natively, substituting scattered
            # futures for the pre-scattered objects in args and kwargs.
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected
    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native
    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native
    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Previous to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel
# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default large array are automatically scattered with
            # broadcast=1 which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0
@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    # Auto-scattering must also work for tasks submitted from within a dask
    # worker (nested joblib calls using get_client()); run twice (retry_no)
    # to check the behavior is stable across repeated sessions.
    np = pytest.importorskip("numpy")
    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )
def test_nested_backend_context_manager(loop_in_thread):
    # Nested Parallel calls running inside dask tasks must complete without
    # deadlocking, including across two successive clients on the same loop.
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2
        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.
    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1
def test_errors(loop):
    """Selecting the dask backend without a reachable client must fail loudly."""
    with pytest.raises(ValueError) as info:
        with parallel_config(backend="dask"):
            pass
    message = str(info.value).lower()
    assert "create a dask client" in message
def test_correct_nested_backend(loop):
    # The nested backend must honor the inner Parallel constraints: plain
    # nesting stays on dask, while require="sharedmem" falls back to threads.
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # No requirement, should be us
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1)
                )
                assert isinstance(result[0][0][0], DaskDistributedBackend)
            # Require threads, should be threading
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require="sharedmem") for _ in range(1)
                )
                assert isinstance(result[0][0][0], ThreadingBackend)


def outer(nested_require):
    # First nesting level: prefer threads, forward the requirement down.
    return Parallel(n_jobs=2, prefer="threads")(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    # Second nesting level: apply the requirement under test.
    return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))


def inner():
    # Innermost level: report which backend Parallel resolves to.
    return Parallel()._backend
def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    # Oversubscribed joblib calls must not deadlock on a process-less client.
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


def _worker_address(_):
    """Return the address of the dask worker executing this task."""
    from distributed import get_worker

    return get_worker().address


def test_dask_backend_keywords(loop):
    # Extra parallel_config keywords (here ``workers``) are forwarded to the
    # dask submission call, allowing tasks to be pinned to specific workers.
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", workers=a["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [a["address"]] * 10
            with parallel_config(backend="dask", workers=b["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [b["address"]] * 10


def test_scheduler_tasks_cleanup(loop):
    # After the Parallel call returns, no tasks or futures should linger on
    # the scheduler; poll for up to 5 seconds for the cleanup to complete.
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))
        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5
        assert not client.futures
@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    # NOTE(review): lexicographic string comparison of version numbers is
    # fragile (e.g. "10.0" < "2.1.1" as strings); consider comparing parsed
    # tuples as done in test_manual_scatter -- confirm before changing.
    distributed.__version__ <= "2.1.1" and distributed.__version__ >= "1.28.0",
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    # Parallel must wait for at least one worker to connect before scheduling
    # work, whether the cluster scales adaptively or is scaled late.
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Tell the cluster to start workers but this is a non-blocking call
        # and new workers might take time to connect. In this case the Parallel
        # call should wait for at least one worker to come up before starting
        # to schedule work.
        cluster.scale(2)
    try:
        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()
def test_wait_for_workers_timeout():
    """Check error reporting when no worker ever joins the cluster."""
    # Start a cluster with 0 worker:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        expected = "DaskDistributedBackend has no worker after 0.1 seconds."
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: the backend gives up waiting and raises.
            with pytest.raises(TimeoutError, match=expected):
                Parallel()(delayed(inc)(i) for i in range(10))
        expected = "DaskDistributedBackend has no active worker"
        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fallback to generic joblib failure:
            with pytest.raises(RuntimeError, match=expected):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()
@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    # Process-based joblib backends cannot start subprocesses from inside a
    # daemonic dask worker; joblib should warn the user about the fallback.
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) make the test hang. Work-around:
            # return the warning record to the client and the warning check is
            # done client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))
            return record

        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()
        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)

View File

@@ -0,0 +1,80 @@
"""
Unit tests for the disk utilities.
"""
# Authors: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Lars Buitinck
# Copyright (c) 2010 Gael Varoquaux
# License: BSD Style, 3 clauses.
from __future__ import with_statement
import array
import os
from joblib.disk import disk_used, memstr_to_bytes, mkdirp, rm_subdirs
from joblib.testing import parametrize, raises
###############################################################################
def test_disk_used(tmpdir):
    cachedir = tmpdir.strpath
    # Now write a file that is 1M big in this directory, and check the
    # size. The reason we use such a big file is that it makes us robust
    # to errors due to block allocation.
    a = array.array("i")
    sizeof_i = a.itemsize
    # target_size is compared against disk_used directly, so it is
    # apparently expressed in KiB -- the file written below is 1 MiB.
    target_size = 1024
    n = int(target_size * 1024 / sizeof_i)
    a = array.array("i", n * (1,))
    with open(os.path.join(cachedir, "test"), "wb") as output:
        a.tofile(output)
    assert disk_used(cachedir) >= target_size
    # Allow a small slack for filesystem block-allocation overhead.
    assert disk_used(cachedir) < target_size + 12
@parametrize(
    "text,value",
    [
        # Suffixes are binary multiples: K=2**10, M=2**20, G=2**30.
        ("80G", 80 * 1024**3),
        ("1.4M", int(1.4 * 1024**2)),
        ("120M", 120 * 1024**2),
        ("53K", 53 * 1024),
    ],
)
def test_memstr_to_bytes(text, value):
    assert memstr_to_bytes(text) == value


@parametrize(
    "text,exception,regex",
    [
        # An unknown suffix or malformed literal raises a descriptive error.
        ("fooG", ValueError, r"Invalid literal for size.*fooG.*"),
        ("1.4N", ValueError, r"Invalid literal for size.*1.4N.*"),
    ],
)
def test_memstr_to_bytes_exception(text, exception, regex):
    with raises(exception) as excinfo:
        memstr_to_bytes(text)
    assert excinfo.match(regex)
def test_mkdirp(tmpdir):
    """mkdirp is idempotent and creates intermediate directories, yet still
    raises for genuinely invalid paths."""
    base = tmpdir.strpath
    mkdirp(os.path.join(base, "ham"))
    mkdirp(os.path.join(base, "ham"))
    mkdirp(os.path.join(base, "spam", "spam"))
    # Not all OSErrors are ignored
    with raises(OSError):
        mkdirp("")
def test_rm_subdirs(tmpdir):
    """rm_subdirs empties a directory but keeps the directory itself."""
    parent = os.path.join(tmpdir.strpath, "subdir_one", "subdir_two")
    child = os.path.join(parent, "subdir_three")
    mkdirp(child)
    rm_subdirs(parent)
    assert os.path.exists(parent)
    assert not os.path.exists(child)

View File

@@ -0,0 +1,338 @@
"""
Test the func_inspect module.
"""
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.
import functools
from joblib.func_inspect import (
_clean_win_chars,
filter_args,
format_signature,
get_func_code,
get_func_name,
)
from joblib.memory import Memory
from joblib.test.common import with_numpy
from joblib.testing import fixture, parametrize, raises
###############################################################################
# Module-level functions and fixture, for tests
# Module-level dummy functions: their *signatures* (not their bodies) are the
# fixture data exercised by the filter_args / format_signature tests below.
def f(x, y=0):
    # one positional parameter, one with a default
    pass


def g(x):
    # single positional parameter
    pass


def h(x, y=0, *args, **kwargs):
    # positional + default + full catch-all
    pass


def i(x=1):
    # defaulted-only parameter
    pass


def j(x, y, **kwargs):
    # positionals plus **kwargs only
    pass


def k(*args, **kwargs):
    # pure catch-all signature
    pass


def m1(x, *, y):
    # keyword-only parameter without a default
    pass


def m2(x, *, y, z=3):
    # keyword-only parameters, one defaulted
    pass
@fixture(scope="module")
def cached_func(tmpdir_factory):
    """Module-scoped fixture returning a Memory-cached function.

    The decorated function is deliberately never called here, so no cache
    directory content is created inside the temp dir.
    """
    cachedir = tmpdir_factory.mktemp("joblib_test_func_inspect")
    memory = Memory(cachedir.strpath)

    @memory.cache
    def cached_func_inner(x):
        return x

    return cached_func_inner
class Klass(object):
    # Minimal class used to exercise bound-method introspection: the method
    # signature (self, x) is itself the fixture data.
    def f(self, x):
        return x
###############################################################################
# Tests
@parametrize(
    "func,args,filtered_args",
    [
        (f, [[], (1,)], {"x": 1, "y": 0}),
        (f, [["x"], (1,)], {"y": 0}),
        (f, [["y"], (0,)], {"x": 0}),
        (f, [["y"], (0,), {"y": 1}], {"x": 0}),
        (f, [["x", "y"], (0,)], {}),
        (f, [[], (0,), {"y": 1}], {"x": 0, "y": 1}),
        (f, [["y"], (), {"x": 2, "y": 1}], {"x": 2}),
        (g, [[], (), {"x": 1}], {"x": 1}),
        (i, [[], (2,)], {"x": 2}),
    ],
)
def test_filter_args(func, args, filtered_args):
    """filter_args keeps exactly the non-ignored arguments, with defaults."""
    observed = filter_args(func, *args)
    assert observed == filtered_args
def test_filter_args_method():
    """Bound methods expose 'self' through filter_args."""
    instance = Klass()
    expected = {"x": 1, "self": instance}
    assert filter_args(instance.f, [], (1,)) == expected
@parametrize(
    "func,args,filtered_args",
    [
        (h, [[], (1,)], {"x": 1, "y": 0, "*": [], "**": {}}),
        (h, [[], (1, 2, 3, 4)], {"x": 1, "y": 2, "*": [3, 4], "**": {}}),
        (h, [[], (1, 25), {"ee": 2}], {"x": 1, "y": 25, "*": [], "**": {"ee": 2}}),
        (h, [["*"], (1, 2, 25), {"ee": 2}], {"x": 1, "y": 2, "**": {"ee": 2}}),
    ],
)
def test_filter_varargs(func, args, filtered_args):
    """Extra *args/**kwargs are reported under the '*' and '**' keys."""
    observed = filter_args(func, *args)
    assert observed == filtered_args
# Keyword-only parameter cases shared with the catch-all cases below.
test_filter_kwargs_extra_params = [
    (m1, [[], (1,), {"y": 2}], {"x": 1, "y": 2}),
    (m2, [[], (1,), {"y": 2}], {"x": 1, "y": 2, "z": 3}),
]


@parametrize(
    "func,args,filtered_args",
    [
        (k, [[], (1, 2), {"ee": 2}], {"*": [1, 2], "**": {"ee": 2}}),
        (k, [[], (3, 4)], {"*": [3, 4], "**": {}}),
    ]
    + test_filter_kwargs_extra_params,
)
def test_filter_kwargs(func, args, filtered_args):
    """Catch-all and keyword-only signatures filter correctly."""
    observed = filter_args(func, *args)
    assert observed == filtered_args
def test_filter_args_2():
    """filter_args handles **kwargs and functools.partial objects."""
    expected = {"x": 1, "y": 2, "**": {"ee": 2}}
    assert filter_args(j, [], (1, 2), {"ee": 2}) == expected

    # filter_args has to special-case partial objects: remaining arguments
    # land under '*'/'**' and the ignore list has no named target.
    partial_f = functools.partial(f, 1)
    assert filter_args(partial_f, [], (1,)) == {"*": [1], "**": {}}
    assert filter_args(partial_f, ["y"], (1,)) == {"*": [1], "**": {}}
@parametrize("func,funcname", [(f, "f"), (g, "g"), (cached_func, "cached_func")])
def test_func_name(func, funcname):
    """get_func_name is not confused by decoration.

    Here the 'cached_func' case passes the fixture function object itself.
    """
    _, name = get_func_name(func)
    assert name == funcname
def test_func_name_on_inner_func(cached_func):
    """The fixture's return value is the decorated inner function; its real
    name must still be recovered through the decoration."""
    _, name = get_func_name(cached_func)
    assert name == "cached_func_inner"
def test_func_name_collision_on_inner_func():
    # Check that two functions defining and caching an inner function
    # with the same name do not cause a (module, name) collision: the
    # enclosing function must be part of the resolved module path.
    def f():
        def inner_func():
            return  # pragma: no cover

        return get_func_name(inner_func)

    def g():
        def inner_func():
            return  # pragma: no cover

        return get_func_name(inner_func)

    module, name = f()
    other_module, other_name = g()
    assert name == other_name
    assert module != other_module
def test_func_inspect_errors():
    """func_inspect helpers degrade gracefully on exotic callables."""
    # Builtin bound methods have a name but no retrievable source.
    assert get_func_name("a".lower)[-1] == "lower"
    assert get_func_code("a".lower)[1:] == (None, -1)

    fn = lambda x: x  # noqa: E731
    expected_file = __file__.replace(".pyc", ".py")
    assert get_func_name(fn, win_characters=False)[-1] == "<lambda>"
    assert get_func_code(fn)[1] == expected_file

    # Simulate a function defined in __main__: source lookup still works.
    fn.__module__ = "__main__"
    assert get_func_name(fn, win_characters=False)[-1] == "<lambda>"
    assert get_func_code(fn)[1] == expected_file
# Signature fixtures for the edge-case test below: keyword-only parameters
# and annotated parameters.
def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
    pass


def func_with_signature(a: int, b: int) -> None:
    pass
def test_filter_args_edge_cases():
    # Keyword-only parameters passed by keyword are filtered normally.
    assert filter_args(func_with_kwonly_args, [], (1, 2), {"kw1": 3, "kw2": 4}) == {
        "a": 1,
        "b": 2,
        "kw1": 3,
        "kw2": 4,
    }

    # Keyword-only parameters cannot be passed positionally: filter_args
    # raises a ValueError naming the offending parameter.
    with raises(ValueError) as excinfo:
        filter_args(func_with_kwonly_args, [], (1, 2, 3), {"kw2": 2})
    excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")

    # Ignored parameters are dropped whether positional or keyword-only.
    assert filter_args(
        func_with_kwonly_args, ["b", "kw2"], (1, 2), {"kw1": 3, "kw2": 4}
    ) == {"a": 1, "kw1": 3}

    # Annotations do not interfere with filtering.
    assert filter_args(func_with_signature, ["b"], (1, 2)) == {"a": 1}
def test_bound_methods():
    """The same method bound to two different instances must resolve to
    different signatures (each includes its own 'self')."""
    first, second = Klass(), Klass()
    assert filter_args(first.f, [], (1,)) != filter_args(second.f, [], (1,))
@parametrize(
    "exception,regex,func,args",
    [
        (
            ValueError,
            "ignore_lst must be a list of parameters to ignore",
            f,
            ["bar", (None,)],
        ),
        (
            ValueError,
            r"Ignore list: argument \'(.*)\' is not defined",
            g,
            [["bar"], (None,)],
        ),
        (ValueError, "Wrong number of arguments", h, [[]]),
    ],
)
def test_filter_args_error_msg(exception, regex, func, args):
    """filter_args failures carry actionable messages for the user."""
    with raises(exception, match=regex):
        filter_args(func, *args)
def test_filter_args_no_kwargs_mutation():
    """Non-regression test against 0.12.0 changes.

    https://github.com/joblib/joblib/pull/75

    filter_args must not mutate the kwargs dict the caller passes in.
    """
    kwargs = {"x": 0}
    filter_args(g, [], [], kwargs)
    assert kwargs == {"x": 0}
def test_clean_win_chars():
    """Characters illegal in Windows paths are stripped from the name."""
    mangled = _clean_win_chars(r"C:\foo\bar\main.py")
    forbidden = ("\\", ":", "<", ">", "!")
    assert not any(ch in mangled for ch in forbidden)
@parametrize(
    "func,args,kwargs,sgn_expected",
    [
        (g, [list(range(5))], {}, "g([0, 1, 2, 3, 4])"),
        (k, [1, 2, (3, 4)], {"y": True}, "k(1, 2, (3, 4), y=True)"),
    ],
)
def test_format_signature(func, args, kwargs, sgn_expected):
    """format_signature renders a readable call signature."""
    _, rendered = format_signature(func, *args, **kwargs)
    assert rendered == sgn_expected
def test_format_signature_long_arguments():
    """Very long arguments are ellipsized in the rendered signature."""
    shortening_threshold = 1500
    # shortening gets it down to 700 characters but there is the name
    # of the function in the signature and a few additional things
    # like dots for the ellipsis
    shortening_target = 700 + 10

    arg = "a" * shortening_threshold
    # single long positional argument
    _, signature = format_signature(h, arg)
    assert len(signature) < shortening_target

    nb_args = 5
    # several long positional arguments
    args = [arg for _ in range(nb_args)]
    _, signature = format_signature(h, *args)
    assert len(signature) < shortening_target * nb_args

    # several long keyword arguments
    kwargs = {str(i): arg for i, arg in enumerate(args)}
    _, signature = format_signature(h, **kwargs)
    assert len(signature) < shortening_target * nb_args

    # both together
    _, signature = format_signature(h, *args, **kwargs)
    assert len(signature) < shortening_target * 2 * nb_args
@with_numpy
def test_format_signature_numpy():
    """Test the format signature formatting with numpy."""
    # NOTE(review): the body is empty, so this currently only verifies that
    # the @with_numpy gating machinery itself runs — confirm whether an
    # actual numpy-argument assertion was intended here.
def test_special_source_encoding():
    """get_func_code handles sources with a non-utf8 (big5) encoding."""
    from joblib.test.test_func_inspect_special_encoding import big5_f

    code, source_file, first_line = get_func_code(big5_f)
    assert first_line == 5
    assert "def big5_f():" in code
    assert "test_func_inspect_special_encoding" in source_file
def _get_code():
    # Module-level helper so it can be pickled by Parallel workers.
    from joblib.test.test_func_inspect_special_encoding import big5_f

    code, _, _ = get_func_code(big5_f)
    return code
def test_func_code_consistency():
    """Source extraction must be identical across worker processes."""
    from joblib.parallel import Parallel, delayed

    extracted = Parallel(n_jobs=2)(delayed(_get_code)() for _ in range(5))
    assert len(set(extracted)) == 1

View File

@@ -0,0 +1,9 @@
# -*- coding: big5 -*-
# Some Traditional Chinese characters: 一些中文字符
def big5_f():
    """用於測試的函數
    """
    # 註釋
    return 0


# NOTE: the Chinese docstring/comments above are intentional fixture data for
# non-utf8 source decoding (see test_special_source_encoding).  Do not
# translate or reflow this file: the test asserts big5_f starts on line 5.

View File

@@ -0,0 +1,520 @@
"""
Test the hashing module.
"""
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.
import collections
import gc
import hashlib
import io
import itertools
import pickle
import random
import sys
import time
from concurrent.futures import ProcessPoolExecutor
from decimal import Decimal
from joblib.func_inspect import filter_args
from joblib.hashing import hash
from joblib.memory import Memory
from joblib.test.common import np, with_numpy
from joblib.testing import fixture, parametrize, raises, skipif
def unicode(s):
    """Identity shim kept from the Python 2 era: return *s* unchanged."""
    return s
###############################################################################
# Helper functions for the tests
def time_func(func, *args):
    """Return the best-of-three wall-clock time of ``func(*args)``."""
    timings = []
    for _ in range(3):
        start = time.time()
        func(*args)
        timings.append(time.time() - start)
    return min(timings)
def relative_time(func1, func2, *args):
    """Return half the normalized wall-clock time gap between *func1* and
    *func2* applied to the same *args (0 means equally fast)."""
    t1 = time_func(func1, *args)
    t2 = time_func(func2, *args)
    return 0.5 * (abs(t1 - t2) / (t1 + t2))
class Klass(object):
    # Plain class used to compare hashes of bound methods across instances.
    def f(self, x):
        return x
class KlassWithCachedMethod(object):
    # Class whose method is replaced at construction time by its
    # Memory-cached wrapper; used to check hashing through cached methods.
    def __init__(self, cachedir):
        mem = Memory(location=cachedir)
        self.f = mem.cache(self.f)

    def f(self, x):
        return x
###############################################################################
# Tests
# Objects of many types, deliberately including pairs that compare equal but
# are distinct objects: the hash test below requires hash equality exactly
# when the objects are identical.
input_list = [
    1,
    2,
    1.0,
    2.0,
    1 + 1j,
    2.0 + 1j,
    "a",
    "b",
    (1,),
    (1, 1),
    [1],
    [1, 1],
    {1: 1},
    {1: 2},
    {2: 1},
    None,
    gc.collect,
    [1].append,
    # The next two sets have unorderable elements under Python 3.
    {"a", 1},
    {"a", 1, ("a", 1)},
    # The next two dicts mix unorderable key types under Python 3.
    {"a": 1, 1: 2},
    {"a": 1, 1: 2, "d": {"a": 1}},
]
@parametrize("obj1", input_list)
@parametrize("obj2", input_list)
def test_trivial_hash(obj1, obj2):
    """Two objects hash identically iff they are the very same object."""
    same_hash = hash(obj1) == hash(obj2)
    same_object = obj1 is obj2
    assert same_hash == same_object
def test_hash_methods():
    """Bound methods hash consistently, reflecting the owner's state."""
    stream = io.StringIO(unicode("a"))
    assert hash(stream.flush) == hash(stream.flush)

    longer = collections.deque(range(10))
    shorter = collections.deque(range(9))
    assert hash(longer.extend) != hash(shorter.extend)
@fixture(scope="function")
@with_numpy
def three_np_arrays():
    # arr1 == arr2 (independent copies); arr3 differs in its first row.
    rnd = np.random.RandomState(0)
    arr1 = rnd.random_sample((10, 10))
    arr2 = arr1.copy()
    arr3 = arr2.copy()
    arr3[0] += 1
    return arr1, arr2, arr3
def test_hash_numpy_arrays(three_np_arrays):
    """Array hashes agree exactly when the array contents agree."""
    arr1, _, _ = three_np_arrays
    for left, right in itertools.product(three_np_arrays, repeat=2):
        hashes_match = hash(left) == hash(right)
        contents_match = np.all(left == right)
        assert hashes_match == contents_match

    # A transposed view shares the data buffer but must hash differently.
    assert hash(arr1) != hash(arr1.T)
def test_hash_numpy_dict_of_arrays(three_np_arrays):
    """Dict hashes depend on values, not on which key holds which array."""
    arr1, arr2, arr3 = three_np_arrays
    swapped_equal = ({1: arr1, 2: arr2}, {1: arr2, 2: arr1})
    different = {1: arr2, 2: arr3}
    assert hash(swapped_equal[0]) == hash(swapped_equal[1])
    assert hash(swapped_equal[0]) != hash(different)
@with_numpy
@parametrize("dtype", ["datetime64[s]", "timedelta64[D]"])
def test_numpy_datetime_array(dtype):
    """Dtypes that reject memoryview (e.g. datetime64) still hash sanely.

    See https://github.com/joblib/joblib/issues/188 for more details.
    """
    reference_hash = hash(np.arange(10))
    datetime_like = np.arange(0, 10, dtype=dtype)
    assert hash(datetime_like) != reference_hash
@with_numpy
def test_hash_numpy_noncontiguous():
    """Memory layout participates in the hash: a non-contiguous view and
    its contiguous copies hash differently."""
    sliced = np.asarray(np.arange(6000).reshape((1000, 2, 3)), order="F")[:, :1, :]
    assert hash(sliced) != hash(np.ascontiguousarray(sliced))
    assert hash(sliced) != hash(np.asfortranarray(sliced))
@with_numpy
@parametrize("coerce_mmap", [True, False])
def test_hash_memmap(tmpdir, coerce_mmap):
    """Check that memmap and arrays hash identically if coerce_mmap is True."""
    filename = tmpdir.join("memmap_temp").strpath
    try:
        m = np.memmap(filename, shape=(10, 10), mode="w+")
        a = np.asarray(m)
        # The hashes agree exactly when coerce_mmap asks for memmaps to be
        # treated as plain arrays.
        are_hashes_equal = hash(a, coerce_mmap=coerce_mmap) == hash(
            m, coerce_mmap=coerce_mmap
        )
        assert are_hashes_equal == coerce_mmap
    finally:
        if "m" in locals():
            del m
            # Force a garbage-collection cycle, to be certain that the
            # object is deleted, and we don't run in a problem under
            # Windows with a file handle still open.
            gc.collect()
@with_numpy
@skipif(
    sys.platform == "win32",
    reason="This test is not stable under windows for some reason",
)
def test_hash_numpy_performance():
    """Check the performance of hashing numpy arrays:

    In [22]: a = np.random.random(1000000)

    In [23]: %timeit hashlib.md5(a).hexdigest()
    100 loops, best of 3: 20.7 ms per loop

    In [24]: %timeit hashlib.md5(pickle.dumps(a, protocol=2)).hexdigest()
    1 loops, best of 3: 73.1 ms per loop

    In [25]: %timeit hashlib.md5(cPickle.dumps(a, protocol=2)).hexdigest()
    10 loops, best of 3: 53.9 ms per loop

    In [26]: %timeit hash(a)
    100 loops, best of 3: 20.8 ms per loop
    """
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(1000000)

    def md5_hash(x):
        return hashlib.md5(memoryview(x)).hexdigest()

    # joblib's hash must stay within ~30% of a raw md5 over the buffer.
    relative_diff = relative_time(md5_hash, hash, a)
    assert relative_diff < 0.3

    # Check that hashing an tuple of 3 arrays takes approximately
    # 3 times as much as hashing one array
    time_hashlib = 3 * time_func(md5_hash, a)
    time_hash = time_func(hash, (a, a, a))
    relative_diff = 0.5 * (abs(time_hash - time_hashlib) / (time_hash + time_hashlib))
    assert relative_diff < 0.3
def test_bound_methods_hash():
    """Equal filtered-arg dicts of the same method on two different
    instances must hash identically."""
    first, second = Klass(), Klass()
    hash_first = hash(filter_args(first.f, [], (1,)))
    hash_second = hash(filter_args(second.f, [], (1,)))
    assert hash_first == hash_second
def test_bound_cached_methods_hash(tmpdir):
    """Same invariant as test_bound_methods_hash, but through the
    Memory-cached wrapper's underlying .func."""
    first = KlassWithCachedMethod(tmpdir.strpath)
    second = KlassWithCachedMethod(tmpdir.strpath)
    hash_first = hash(filter_args(first.f.func, [], (1,)))
    hash_second = hash(filter_args(second.f.func, [], (1,)))
    assert hash_first == hash_second
@with_numpy
def test_hash_object_dtype():
    """ndarrays with dtype object hash by value, not identity."""
    first = np.array([np.arange(i) for i in range(6)], dtype=object)
    second = np.array([np.arange(i) for i in range(6)], dtype=object)
    assert hash(first) == hash(second)
@with_numpy
def test_numpy_scalar():
    # Numpy scalars are built from compiled functions and used to explore
    # strange pickling paths that could produce hash collisions.
    assert hash(np.float64(2.0)) != hash(np.float64(3.0))
def test_dict_hash(tmpdir):
    """Dicts hash consistently even though key ordering is not guaranteed."""
    klass = KlassWithCachedMethod(tmpdir.strpath)
    subject_ids = (
        "#s12069", "#s12158", "#s12258", "#s12277", "#s12300", "#s12401",
        "#s12430", "#s13817", "#s13903", "#s13916", "#s13981", "#s13982",
        "#s13983",
    )
    d = {"%s__c_maps.nii.gz" % sid: [33] for sid in subject_ids}
    a = klass.f(d)
    b = klass.f(a)
    assert hash(a) == hash(b)
def test_set_hash(tmpdir):
    """Sets hash consistently even though iteration order is not guaranteed."""
    klass = KlassWithCachedMethod(tmpdir.strpath)
    subject_ids = (
        "#s12069", "#s12158", "#s12258", "#s12277", "#s12300", "#s12401",
        "#s12430", "#s13817", "#s13903", "#s13916", "#s13981", "#s13982",
        "#s13983",
    )
    s = {"%s__c_maps.nii.gz" % sid for sid in subject_ids}
    a = klass.f(s)
    b = klass.f(a)
    assert hash(a) == hash(b)
def test_set_decimal_hash():
    # Sets containing unorderable Decimals ('NaN' does not compare) must
    # still hash independently of construction order.
    one_way = {Decimal(0), Decimal("NaN")}
    other_way = {Decimal("NaN"), Decimal(0)}
    assert hash(one_way) == hash(other_way)
def test_string():
    # Strings are immutable and can be interned/shared: a container's hash
    # must not depend on the history of the strings it holds.
    key = "foo"
    first = {key: "bar"}
    second = {key: "bar"}
    round_tripped = pickle.loads(pickle.dumps(second))
    assert hash([first, second]) == hash([first, round_tripped])
@with_numpy
def test_numpy_dtype_pickling():
    """dtype hashing must be robust to object identity and copies."""
    # numpy dtype hashing is tricky to get right: see #231, #239, #251 #1080,
    # #1082, and explanatory comments inside
    # ``joblib.hashing.NumpyHasher.save``.
    # In this test, we make sure that the pickling of numpy dtypes is robust to
    # object identity and object copy.

    dt1 = np.dtype("f4")
    dt2 = np.dtype("f4")

    # simple dtypes objects are interned
    assert dt1 is dt2
    assert hash(dt1) == hash(dt2)

    dt1_roundtripped = pickle.loads(pickle.dumps(dt1))
    assert dt1 is not dt1_roundtripped
    assert hash(dt1) == hash(dt1_roundtripped)

    assert hash([dt1, dt1]) == hash([dt1_roundtripped, dt1_roundtripped])
    assert hash([dt1, dt1]) == hash([dt1, dt1_roundtripped])

    complex_dt1 = np.dtype([("name", np.str_, 16), ("grades", np.float64, (2,))])
    complex_dt2 = np.dtype([("name", np.str_, 16), ("grades", np.float64, (2,))])

    # complex dtypes objects are not interned
    assert hash(complex_dt1) == hash(complex_dt2)

    complex_dt1_roundtripped = pickle.loads(pickle.dumps(complex_dt1))
    assert complex_dt1_roundtripped is not complex_dt1
    assert hash(complex_dt1) == hash(complex_dt1_roundtripped)

    assert hash([complex_dt1, complex_dt1]) == hash(
        [complex_dt1_roundtripped, complex_dt1_roundtripped]
    )
    assert hash([complex_dt1, complex_dt1]) == hash(
        [complex_dt1_roundtripped, complex_dt1]
    )
@parametrize(
    "to_hash,expected",
    [
        ("This is a string to hash", "71b3f47df22cb19431d85d92d0b230b2"),
        ("C'est l\xe9t\xe9", "2d8d189e9b2b0b2e384d93c868c0e576"),
        ((123456, 54321, -98765), "e205227dd82250871fa25aa0ec690aa3"),
        (
            [random.Random(42).random() for _ in range(5)],
            "a11ffad81f9682a7d901e6edc3d16c84",
        ),
        ({"abcde": 123, "sadfas": [-9999, 2, 3]}, "aeda150553d4bb5c69f0e69d51b0e2ef"),
    ],
)
def test_hashes_stay_the_same(to_hash, expected):
    """Hash values are pinned to the outputs of joblib 0.9.2."""
    # We want to make sure that hashes don't change with joblib
    # version. For end users, that would mean that they have to
    # regenerate their cache from scratch, which potentially means
    # lengthy recomputations.
    # Expected results have been generated with joblib 0.9.2
    assert hash(to_hash) == expected
@with_numpy
def test_hashes_are_different_between_c_and_fortran_contiguous_arrays():
    """C- and Fortran-contiguous copies of the same data hash differently."""
    arr_c = np.random.RandomState(0).random_sample((10, 10))
    arr_f = np.asfortranarray(arr_c)
    assert hash(arr_c) != hash(arr_f)
@with_numpy
def test_0d_array():
    # Smoke test: hashing a 0-d array must not raise.
    zero_dim = np.array(0)
    hash(zero_dim)
@with_numpy
def test_0d_and_1d_array_hashing_is_different():
    """A scalar array and a length-1 vector are distinct objects."""
    zero_dim, one_dim = np.array(0), np.array([0])
    assert hash(zero_dim) != hash(one_dim)
@with_numpy
def test_hashes_stay_the_same_with_numpy_objects():
    """Numpy-object hashes are reproducible across copies and processes."""
    # Note: joblib used to test numpy objects hashing by comparing the produced
    # hash of an object with some hard-coded target value to guarantee that
    # hashing remains the same across joblib versions. However, since numpy
    # 1.20 and joblib 1.0, joblib relies on potentially unstable implementation
    # details of numpy to hash np.dtype objects, which makes the stability of
    # hash values across different environments hard to guarantee and to test.
    # As a result, hashing stability across joblib versions becomes best-effort
    # only, and we only test the consistency within a single environment by
    # making sure:
    # - the hash of two copies of the same objects is the same
    # - hashing some object in two different python processes produces the same
    #   value. This should be viewed as a proxy for testing hash consistency
    #   through time between Python sessions (provided no change in the
    #   environment was done between sessions).

    def create_objects_to_hash():
        rng = np.random.RandomState(42)
        # Being explicit about dtypes in order to avoid
        # architecture-related differences. Also using 'f4' rather than
        # 'f8' for float arrays because 'f8' arrays generated by
        # rng.random.randn don't seem to be bit-identical on 32bit and
        # 64bit machines.
        to_hash_list = [
            rng.randint(-1000, high=1000, size=50).astype("<i8"),
            tuple(rng.randn(3).astype("<f4") for _ in range(5)),
            [rng.randn(3).astype("<f4") for _ in range(5)],
            {
                -3333: rng.randn(3, 5).astype("<f4"),
                0: [
                    rng.randint(10, size=20).astype("<i8"),
                    rng.randn(10).astype("<f4"),
                ],
            },
            # Non regression cases for
            # https://github.com/joblib/joblib/issues/308
            np.arange(100, dtype="<i8").reshape((10, 10)),
            # Fortran contiguous array
            np.asfortranarray(np.arange(100, dtype="<i8").reshape((10, 10))),
            # Non contiguous array
            np.arange(100, dtype="<i8").reshape((10, 10))[:, :2],
        ]
        return to_hash_list

    # Create two lists containing copies of the same objects.  joblib.hash
    # should return the same hash for to_hash_list_one[i] and
    # to_hash_list_two[i]
    to_hash_list_one = create_objects_to_hash()
    to_hash_list_two = create_objects_to_hash()

    e1 = ProcessPoolExecutor(max_workers=1)
    e2 = ProcessPoolExecutor(max_workers=1)

    try:
        for obj_1, obj_2 in zip(to_hash_list_one, to_hash_list_two):
            # testing consistency of hashes across python processes
            hash_1 = e1.submit(hash, obj_1).result()
            hash_2 = e2.submit(hash, obj_1).result()
            assert hash_1 == hash_2
            # testing consistency when hashing two copies of the same objects.
            hash_3 = e1.submit(hash, obj_2).result()
            assert hash_1 == hash_3
    finally:
        e1.shutdown()
        e2.shutdown()
def test_hashing_pickling_error():
    """Unpicklable objects surface a PicklingError with hashing context."""

    def local_function():
        return 42

    # A function defined inside another function cannot be pickled.
    with raises(pickle.PicklingError, match="PicklingError while hashing"):
        hash(local_function)
def test_wrong_hash_name():
    """An unknown hash_name is rejected with a helpful message."""
    with raises(ValueError, match="Valid options for 'hash_name' are"):
        hash({"foo": "bar"}, hash_name="invalid")

View File

@@ -0,0 +1,15 @@
# Basic test case to test functioning of module's top-level
try:
    from joblib import *  # noqa

    _top_import_error = None
except Exception as ex:  # pragma: no cover
    # Stash the failure so the test below can report it: 'import *' is only
    # legal at module level, so the check cannot happen inside the test.
    _top_import_error = ex
def test_import_joblib():
    """The module-level 'from joblib import *' must have succeeded."""
    # The import happens at module level (where 'import *' is legal) and
    # records any failure in _top_import_error for us to check here.
    assert _top_import_error is None

View File

@@ -0,0 +1,29 @@
"""
Test the logger module.
"""
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.
import re
from joblib.logger import PrintTime
def test_print_time(tmpdir, capsys):
    """Smoke-test PrintTime, including its log-rotation behavior."""
    logfile = tmpdir.join("test.log").strpath
    # Instantiate three times against the same logfile: the second and third
    # constructions exercise the log-rotation code path.
    for _ in range(3):
        print_time = PrintTime(logfile=logfile)
        print_time("Foo")
    out_printed_text, err_printed_text = capsys.readouterr()
    # Match stderr with a regexp to be robust to timing variations.
    # (All literal dots are escaped; the original pattern left some dots
    # unescaped, matching any character.)
    match = (
        r"Foo: 0\..s, 0\..min\nFoo: 0\..s, 0\..min\nFoo: " + r".\..s, 0\..min\n"
    )
    if not re.match(match, err_printed_text):
        # Fixed typo: "Excepted" -> "Expected" in the failure message.
        raise AssertionError("Expected %s, got %s" % (match, err_printed_text))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,180 @@
import asyncio
import gc
import shutil
import pytest
from joblib.memory import (
AsyncMemorizedFunc,
AsyncNotMemorizedFunc,
MemorizedResult,
Memory,
NotMemorizedResult,
)
from joblib.test.common import np, with_numpy
from joblib.testing import raises
from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn
async def check_identity_lazy_async(func, accumulator, location):
    """Coroutine variant of check_identity_lazy.

    Cache *func* with a Memory at *location* and check lazy evaluation:
    each distinct input is computed exactly once (the accumulator grows
    only on real calls), while the cached value is still returned.
    """
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
    for i in range(3):
        for _ in range(2):
            value = await func(i)
            assert value == i
            assert len(accumulator) == i + 1
@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """End-to-end caching of a coroutine: laziness, clearing, Memory.eval."""
    accumulator = list()

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First clear the cache directory, to check that our code can
            # handle that
            # NOTE: this line would raise an exception, as the database
            # file is still open; we ignore the error since we want to
            # test what happens if the directory disappears
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)
            # Clearing the cache forces exactly one recomputation.
            assert len(accumulator) == current_accumulator + 1

            # Also, check that Memory.eval works similarly
            evaled = await memory.eval(f, 1)
            assert evaled == out
            assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = "__main__"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)
@pytest.mark.asyncio
async def test_no_memory_async():
    """With location=None caching is disabled: every call recomputes."""
    calls = list()

    async def ff(x):
        await asyncio.sleep(0.1)
        calls.append(1)
        return x

    memory = Memory(location=None, verbose=0)
    cached = memory.cache(ff)
    for _ in range(4):
        before = len(calls)
        await cached(1)
        assert len(calls) == before + 1
@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""
    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    a = np.ones(3)
    b = await twice(a)
    c = await twice(a)

    # Both the first (computing) and second (cache-hit) results come back
    # as read-only memmaps.
    assert isinstance(c, np.memmap)
    assert c.mode == "r"
    assert isinstance(b, np.memmap)
    assert b.mode == "r"

    # Corrupts the file.  Deleting the b and c mmaps first is necessary to
    # be able to edit the file (Windows keeps it locked otherwise).
    del b
    del c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = await twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    # Asserts that the recomputation returns a mmap
    assert isinstance(d, np.memmap)
    assert d.mode == "r"
@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """call_and_shelve on coroutines yields the matching (Not)MemorizedResult."""

    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x**2 + y

    # Test MemorizedFunc outputting a reference to cache.  Each wrapper is
    # paired with the Result type its call_and_shelve must produce.
    for func, Result in zip(
        (
            AsyncMemorizedFunc(f, tmpdir.strpath),
            AsyncNotMemorizedFunc(f),
            Memory(location=tmpdir.strpath, verbose=0).cache(f),
            Memory(location=None).cache(f),
        ),
        (
            MemorizedResult,
            NotMemorizedResult,
            MemorizedResult,
            NotMemorizedResult,
        ),
    ):
        for _ in range(2):
            result = await func.call_and_shelve(2)
            assert isinstance(result, Result)
            assert result.get() == 5

        # After clearing, the shelved value is gone.
        result.clear()
        with raises(KeyError):
            result.get()
        result.clear()  # Do nothing if there is no cache.
@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """`.call` always executes, bypassing the cache, and returns metadata."""

    async def inner(x, counter):
        await asyncio.sleep(0.1)
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    cached = memory.cache(inner, ignore=["counter"])
    counter = {}
    # 'counter' is ignored, so the second call is a cache hit.
    assert await cached(2, counter) == 1
    assert await cached(2, counter) == 1
    # .call forces execution and returns (result, metadata).
    result, metadata = await cached.call(2, counter)
    assert result == 2, "f has not been called properly"
    assert isinstance(metadata, dict), "Metadata are not returned by MemorizedFunc.call."

View File

@@ -0,0 +1,36 @@
"""
Pyodide and other single-threaded Python builds will be missing the
_multiprocessing module. Test that joblib still works in this environment.
"""
import os
import subprocess
import sys
def test_missing_multiprocessing(tmp_path):
    """
    Test that import joblib works even if _multiprocessing is missing.

    pytest has already imported everything from joblib. The most reasonable way
    to test importing joblib with modified environment is to invoke a separate
    Python process. This also ensures that we don't break other tests by
    importing a bad `_multiprocessing` module.
    """
    (tmp_path / "_multiprocessing.py").write_text(
        'raise ImportError("No _multiprocessing module!")'
    )
    env = dict(os.environ)
    # For subprocess, use current sys.path with our custom version of
    # multiprocessing inserted.  Join with os.pathsep (not a literal ':')
    # so the resulting PYTHONPATH is also valid on Windows, where the
    # separator is ';'.
    env["PYTHONPATH"] = os.pathsep.join([str(tmp_path)] + sys.path)
    subprocess.check_call(
        [
            sys.executable,
            "-c",
            "import joblib, math; "
            "joblib.Parallel(n_jobs=1)("
            "joblib.delayed(math.sqrt)(i**2) for i in range(10))",
        ],
        env=env,
    )

View File

@@ -0,0 +1,55 @@
import sys
import joblib
from joblib.test.common import with_multiprocessing
from joblib.testing import check_subprocess_call
def test_version():
    # joblib must always expose a __version__ attribute.
    msg = "There are no __version__ argument on the joblib module"
    assert hasattr(joblib, "__version__"), msg
@with_multiprocessing
def test_no_start_method_side_effect_on_import():
    # check that importing joblib does not implicitly set the global
    # start_method for multiprocessing.
    code = """if True:
        import joblib
        import multiprocessing as mp
        # The following line would raise RuntimeError if the
        # start_method is already set.
        mp.set_start_method("loky")
        """
    check_subprocess_call([sys.executable, "-c", code])
@with_multiprocessing
def test_no_semaphore_tracker_on_import():
    # check that importing joblib does not implicitly spawn a resource tracker
    # or a semaphore tracker.  (The comment embedded in the snippet below is
    # part of the executed string and is left untouched; the actual check is
    # the assert on the tracker's file descriptor.)
    code = """if True:
        import joblib
        from multiprocessing import semaphore_tracker
        # The following line would raise RuntimeError if the
        # start_method is already set.
        msg = "multiprocessing.semaphore_tracker has been spawned on import"
        assert semaphore_tracker._semaphore_tracker._fd is None, msg"""
    if sys.version_info >= (3, 8):
        # semaphore_tracker was renamed in Python 3.8:
        code = code.replace("semaphore_tracker", "resource_tracker")
    check_subprocess_call([sys.executable, "-c", code])
@with_multiprocessing
def test_no_resource_tracker_on_import():
    # check that importing joblib does not implicitly spawn loky's own
    # resource tracker.  (The comment embedded in the snippet is part of the
    # executed string; the assert on the tracker fd is the actual check.)
    code = """if True:
        import joblib
        from joblib.externals.loky.backend import resource_tracker
        # The following line would raise RuntimeError if the
        # start_method is already set.
        msg = "loky.resource_tracker has been spawned on import"
        assert resource_tracker._resource_tracker._fd is None, msg
    """
    check_subprocess_call([sys.executable, "-c", code])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
"""Test the old numpy pickler, compatibility version."""
# numpy_pickle is not a drop-in replacement of pickle, as it takes
# filenames instead of open files as arguments.
from joblib import numpy_pickle_compat
def test_z_file(tmpdir):
    """Round-trip data through the legacy Zfile format."""
    filename = tmpdir.join("test.pkl").strpath
    payload = numpy_pickle_compat.asbytes("Foo, \n Bar, baz, \n\nfoobar")
    with open(filename, "wb") as stream:
        numpy_pickle_compat.write_zfile(stream, payload)
    with open(filename, "rb") as stream:
        assert numpy_pickle_compat.read_zfile(stream) == payload

View File

@@ -0,0 +1,9 @@
from joblib.compressor import BinaryZlibFile
from joblib.testing import parametrize
@parametrize("filename", ["test", "test"])  # testing str and unicode names
def test_binary_zlib_file(tmpdir, filename):
    """Creating a BinaryZlibFile works whatever the filename type."""
    path = tmpdir.join(filename).strpath
    BinaryZlibFile(path, mode="wb").close()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# joblib only supports Python 3, where the ``pickle`` module transparently
# uses the fast C implementation; the historical Python 2.7 ``cPickle``
# fallback is dead code and has been removed. The ``cpickle`` alias is kept
# because the rest of this file refers to it.
import pickle as cpickle
import functools
import time
from pickle import PicklingError
import pytest
from joblib import Parallel, delayed
from joblib._store_backends import (
CacheWarning,
FileSystemStoreBackend,
concurrency_safe_write,
)
from joblib.backports import concurrency_safe_rename
from joblib.test.common import with_multiprocessing
from joblib.testing import parametrize, timeout
def write_func(output, filename):
    """Pickle ``output`` into the file at ``filename``."""
    with open(filename, "wb") as handle:
        cpickle.dump(output, handle)
def load_func(expected, filename):
    """Unpickle ``filename``, retrying transient errors, and check its value.

    On Windows reading can transiently fail with a WindowsError ([Error 5]
    Access is denied or [Error 13] Permission denied), probably because a
    writer process has a lock on the file, so up to 10 attempts are made
    with a short pause in between.

    Raises the last I/O error if every attempt fails.
    """
    last_error = None
    for _ in range(10):
        try:
            with open(filename, "rb") as f:
                reloaded = cpickle.load(f)
            break
        except OSError as exc:  # IOError is an alias of OSError on Python 3
            last_error = exc
            time.sleep(0.1)
    else:
        # The original bare ``raise`` here would itself fail with
        # RuntimeError("No active exception to re-raise") because the except
        # clause has already exited; re-raise the captured error instead.
        raise last_error
    assert expected == reloaded
def concurrency_safe_write_rename(to_write, filename, write_func):
    """Write to a temporary file, then atomically move it into place."""
    tmp_name = concurrency_safe_write(to_write, filename, write_func)
    concurrency_safe_rename(tmp_name, filename)
@timeout(0)  # No timeout as this test can be long
@with_multiprocessing
@parametrize("backend", ["multiprocessing", "loky", "threading"])
def test_concurrency_safe_write(tmpdir, backend):
    """Concurrent writers and readers of one cache file must not clash."""
    filename = tmpdir.join("test.pkl").strpath
    obj = {str(i): i for i in range(int(1e5))}
    # Interleave two writers for every reader (indices 2, 5, 8, 11 read).
    writer = functools.partial(concurrency_safe_write_rename, write_func=write_func)
    funcs = []
    for i in range(12):
        funcs.append(load_func if i % 3 == 2 else writer)
    Parallel(n_jobs=2, backend=backend)(delayed(func)(obj, filename) for func in funcs)
def test_warning_on_dump_failure(tmpdir):
    """Any dump failure other than a PicklingError raises a CacheWarning."""

    class UnpicklableObject(object):
        def __reduce__(self):
            raise RuntimeError("some exception")

    store = FileSystemStoreBackend()
    store.location = tmpdir.join("test_warning_on_pickling_error").strpath
    store.compress = None
    with pytest.warns(CacheWarning, match="some exception"):
        store.dump_item("testpath", UnpicklableObject())
def test_warning_on_pickling_error(tmpdir):
    """A PicklingError during dump currently raises a FutureWarning.

    Kept separate from test_warning_on_dump_failure because in the future
    this case will be turned into a hard exception.
    """

    class UnpicklableObject(object):
        def __reduce__(self):
            raise PicklingError("not picklable")

    store = FileSystemStoreBackend()
    store.location = tmpdir.join("test_warning_on_pickling_error").strpath
    store.compress = None
    with pytest.warns(FutureWarning, match="not picklable"):
        store.dump_item("testpath", UnpicklableObject())

View File

@@ -0,0 +1,87 @@
import re
import sys
from joblib.testing import check_subprocess_call, raises
def test_check_subprocess_call():
    """check_subprocess_call runs a snippet and can match its stdout."""
    code = "result = 1 + 2 * 3\nprint(result)\nmy_list = [1, 2, 3]\nprint(my_list)"
    # Plain run, no stdout check.
    check_subprocess_call([sys.executable, "-c", code])
    # Same run with a regex on stdout; \s{1,2} absorbs platform-specific
    # line endings.
    check_subprocess_call(
        [sys.executable, "-c", code],
        stdout_regex=r"7\s{1,2}\[1, 2, 3\]",
    )
def test_check_subprocess_call_non_matching_regex():
    """A stdout regex that does not match must raise ValueError."""
    non_matching_pattern = "_no_way_this_matches_anything_"
    with raises(ValueError) as excinfo:
        check_subprocess_call(
            [sys.executable, "-c", "42"], stdout_regex=non_matching_pattern
        )
    excinfo.match("Unexpected stdout.+{}".format(non_matching_pattern))
def test_check_subprocess_call_wrong_command():
    """Invoking a non-existent executable must raise OSError."""
    with raises(OSError):
        check_subprocess_call(["_a_command_that_does_not_exist_"])
def test_check_subprocess_call_non_zero_return_code():
    """A non-zero exit status is reported with the captured stdout/stderr."""
    failing_code = (
        "import sys\n"
        'print("writing on stdout")\n'
        'sys.stderr.write("writing on stderr")\n'
        "sys.exit(123)"
    )
    expected = re.compile(
        "Non-zero return code: 123.+"
        "Stdout:\nwriting on stdout.+"
        "Stderr:\nwriting on stderr",
        re.DOTALL,
    )
    with raises(ValueError) as excinfo:
        check_subprocess_call([sys.executable, "-c", failing_code])
    excinfo.match(expected)
def test_check_subprocess_call_timeout():
    """A child outliving ``timeout`` is killed and reported via ValueError."""
    slow_code = (
        "import time\n"
        "import sys\n"
        'print("before sleep on stdout")\n'
        "sys.stdout.flush()\n"
        'sys.stderr.write("before sleep on stderr")\n'
        "sys.stderr.flush()\n"
        # We need to sleep for at least 2 * timeout seconds in case the
        # SIGKILL is triggered.
        "time.sleep(10)\n"
        'print("process should have be killed before")\n'
        "sys.stdout.flush()"
    )
    expected = re.compile(
        "Non-zero return code:.+"
        "Stdout:\nbefore sleep on stdout\\s+"
        "Stderr:\nbefore sleep on stderr",
        re.DOTALL,
    )
    with raises(ValueError) as excinfo:
        check_subprocess_call([sys.executable, "-c", slow_code], timeout=1)
    excinfo.match(expected)

View File

@@ -0,0 +1,45 @@
import pytest
from joblib._utils import eval_expr
@pytest.mark.parametrize(
    "expr",
    (
        "exec('import os')",
        "print(1)",
        "import os",
        "1+1; import os",
        "1^1",
        "' ' * 10**10",
        "9. ** 10000.",
    ),
)
def test_eval_expr_invalid(expr):
    """Non-arithmetic, unsafe, or unsupported expressions are rejected."""
    expected_msg = "is not a valid or supported arithmetic"
    with pytest.raises(ValueError, match=expected_msg):
        eval_expr(expr)
def test_eval_expr_too_long():
    """Overly long expressions are rejected before evaluation."""
    long_expr = "+".join(["1"] * 51)
    with pytest.raises(ValueError, match="is too long"):
        eval_expr(long_expr)
@pytest.mark.parametrize("expr", ("1e7", "10**7", "9**9**9"))
def test_eval_expr_too_large_literal(expr):
    """Numeric literals above the allowed magnitude are rejected."""
    with pytest.raises(ValueError, match="Numeric literal .* is too large"):
        eval_expr(expr)
@pytest.mark.parametrize(
    "expr, result",
    (
        ("2*6", 12),
        ("2**6", 64),
        ("1 + 2*3**(4) / (6 + -7)", -161.0),
        ("(20 // 3) % 5", 1),
    ),
)
def test_eval_expr_valid(expr, result):
    """Plain arithmetic expressions evaluate to the expected value."""
    assert eval_expr(expr) == result

View File

@@ -0,0 +1,9 @@
def return_slice_of_data(arr, start_idx, end_idx):
    """Return the ``[start_idx:end_idx]`` slice of ``arr``."""
    sliced = arr[start_idx:end_idx]
    return sliced
def print_filename_and_raise(arr):
    """Print the filename of the memmap backing ``arr``, then raise
    ValueError."""
    from joblib._memmapping_reducer import _get_backing_memmap

    backing = _get_backing_memmap(arr)
    print(backing.filename)
    raise ValueError