Initial commit
This commit is contained in:
84
backend/venv/Lib/site-packages/joblib/test/common.py
Normal file
84
backend/venv/Lib/site-packages/joblib/test/common.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Small utilities for testing.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import sysconfig
|
||||
|
||||
from joblib._multiprocessing_helpers import mp
|
||||
from joblib.testing import SkipTest, skipif
|
||||
|
||||
# lz4 is an optional compression backend: expose it as None when it is not
# installed so the skip markers below can test for its presence.
try:
    import lz4
except ImportError:
    lz4 = None
|
||||
|
||||
# TODO straight removal since in joblib.test.common?
# True when running under the PyPy interpreter.
IS_PYPY = hasattr(sys, "pypy_version_info")
# True on free-threaded CPython builds where the GIL is actually disabled.
# NOTE: the short-circuit matters -- sys._is_gil_enabled() only exists on
# builds compiled with Py_GIL_DISABLED, so it must not be evaluated otherwise.
IS_GIL_DISABLED = (
    sysconfig.get_config_var("Py_GIL_DISABLED") and not sys._is_gil_enabled()
)
|
||||
|
||||
# A decorator to run tests only when numpy is available
try:
    import numpy as np

    def with_numpy(func):
        """A decorator to skip tests requiring numpy."""
        return func

except ImportError:
    np = None

    def with_numpy(func):
        """A decorator to skip tests requiring numpy."""

        def skipped():
            raise SkipTest("Test requires numpy")

        return skipped


# TODO: Turn this back on after refactoring yield based tests in test_hashing
# with_numpy = skipif(not np, reason='Test requires numpy.')
|
||||
|
||||
# we use memory_profiler library for memory consumption checks
try:
    from memory_profiler import memory_usage

    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""
        return func

    def memory_used(func, *args, **kwargs):
        """Compute memory usage when executing func."""
        gc.collect()
        profile = memory_usage((func, args, kwargs), interval=0.001)
        # Peak-to-trough delta over the sampled profile.
        return max(profile) - min(profile)

except ImportError:
    memory_usage = memory_used = None

    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""

        def skipped_func():
            raise SkipTest("Test requires memory_profiler.")

        return skipped_func
|
||||
|
||||
|
||||
# Skip when the stdlib multiprocessing module is unavailable on this platform.
with_multiprocessing = skipif(mp is None, reason="Needs multiprocessing to run.")


# Shared-memory tests need a POSIX /dev/shm mount point.
with_dev_shm = skipif(
    not os.path.exists("/dev/shm"),
    reason="This test requires a large /dev/shm shared memory fs.",
)

# Markers gating tests on the optional lz4 dependency (see fallback above).
with_lz4 = skipif(lz4 is None, reason="Needs lz4 compression to run")

without_lz4 = skipif(lz4 is not None, reason="Needs lz4 not being installed to run")
|
||||
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
This script is used to generate test data for joblib/test/test_numpy_pickle.py
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
|
||||
# pytest needs to be able to import this module even when numpy is
|
||||
# not installed
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
np = None
|
||||
|
||||
import joblib
|
||||
|
||||
|
||||
def get_joblib_version(joblib_version=joblib.__version__):
    """Normalize joblib version by removing suffix.

    >>> get_joblib_version('0.8.4')
    '0.8.4'
    >>> get_joblib_version('0.8.4b1')
    '0.8.4'
    >>> get_joblib_version('0.9.dev0')
    '0.9'
    """
    # Keep only the leading digits of each dot-separated component; drop
    # components that carry no digits at all (e.g. pure "dev" tags).
    numeric_parts = []
    for part in joblib_version.split("."):
        match = re.match(r"(\d+).*", part)
        if match is not None:
            numeric_parts.append(match.group(1))
    return ".".join(numeric_parts)
|
||||
|
||||
|
||||
def write_test_pickle(to_pickle, args):
    """Dump ``to_pickle`` into a version-stamped fixture file.

    The file name encodes the joblib, Python and numpy versions plus the
    chosen compression options so test_numpy_pickle.py can locate the
    matching fixture later.
    """
    dump_kwargs = {}
    compress = args.compress
    method = args.method
    joblib_version = get_joblib_version()
    py_version = "{0[0]}{0[1]}".format(sys.version_info)
    numpy_version = "".join(np.__version__.split(".")[:2])

    # The game here is to generate the right filename according to the options.
    body = "_compressed" if (compress and method == "zlib") else ""
    if not compress:
        extension = ".pkl"
    else:
        if method == "zlib":
            dump_kwargs["compress"] = True
            extension = ".gz"
        else:
            dump_kwargs["compress"] = (method, 3)
            extension = ".pkl.{}".format(method)
        if args.cache_size:
            dump_kwargs["cache_size"] = 0
            body += "_cache_size"

    pickle_filename = "joblib_{}{}_pickle_py{}_np{}{}".format(
        joblib_version, body, py_version, numpy_version, extension
    )

    try:
        joblib.dump(to_pickle, pickle_filename, **dump_kwargs)
    except Exception as e:
        # With old python version (=< 3.3.), we can arrive there when
        # dumping compressed pickle with LzmaFile.
        print(
            "Error: cannot generate file '{}' with arguments '{}'. "
            "Error was: {}".format(pickle_filename, dump_kwargs, e)
        )
    else:
        print("File '{}' generated successfully.".format(pickle_filename))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command-line options selecting which fixture variant to produce.
    parser = argparse.ArgumentParser(description="Joblib pickle data generator.")
    parser.add_argument(
        "--cache_size",
        action="store_true",
        help="Force creation of companion numpy files for pickled arrays.",
    )
    parser.add_argument(
        "--compress", action="store_true", help="Generate compress pickles."
    )
    parser.add_argument(
        "--method",
        type=str,
        default="zlib",
        choices=["zlib", "gzip", "bz2", "xz", "lzma", "lz4"],
        help="Set compression method.",
    )
    # We need to be specific about dtypes in particular endianness
    # because the pickles can be generated on one architecture and
    # the tests run on another one. See
    # https://github.com/joblib/joblib/issues/279.
    to_pickle = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        # all possible bytes as a byte string
        np.arange(256, dtype=np.uint8).tobytes(),
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        # unicode string with non-ascii chars
        "C'est l'\xe9t\xe9 !",
    ]

    write_test_pickle(to_pickle, parser.parse_args())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
35
backend/venv/Lib/site-packages/joblib/test/test_backports.py
Normal file
35
backend/venv/Lib/site-packages/joblib/test/test_backports.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import mmap
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib.backports import concurrency_safe_rename, make_memmap
|
||||
from joblib.test.common import with_numpy
|
||||
from joblib.testing import parametrize
|
||||
|
||||
|
||||
@with_numpy
def test_memmap(tmpdir):
    """make_memmap must honor a non page-aligned byte offset."""
    backing_file = tmpdir.join("test.mmap").strpath
    # One byte past an allocation-granularity boundary exercises the
    # offset handling in the backport.
    requested_offset = mmap.ALLOCATIONGRANULARITY + 1
    mm = make_memmap(
        backing_file,
        shape=5 * mmap.ALLOCATIONGRANULARITY,
        mode="w+",
        offset=requested_offset,
    )
    assert mm.offset == requested_offset
|
||||
|
||||
|
||||
@parametrize("dst_content", [None, "dst content"])
@parametrize("backend", [None, "threading"])
def test_concurrency_safe_rename(tmpdir, dst_content, backend):
    """Concurrent renames onto one destination leave exactly one winner."""
    sources = [tmpdir.join("src_%d" % i) for i in range(4)]
    for source in sources:
        source.write("src content")
    destination = tmpdir.join("dst")
    if dst_content is not None:
        destination.write(dst_content)

    # Four workers race to rename their own file onto the same destination.
    Parallel(n_jobs=4, backend=backend)(
        delayed(concurrency_safe_rename)(source.strpath, destination.strpath)
        for source in sources
    )
    assert destination.exists()
    assert destination.read() == "src content"
    for source in sources:
        assert not source.exists()
|
||||
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Test that our implementation of wrap_non_picklable_objects mimics
|
||||
properly the loky implementation.
|
||||
"""
|
||||
|
||||
from .._cloudpickle_wrapper import (
|
||||
_my_wrap_non_picklable_objects,
|
||||
wrap_non_picklable_objects,
|
||||
)
|
||||
|
||||
|
||||
def a_function(x):
    """Module-level identity function used as a pickling test target."""
    return x
|
||||
|
||||
|
||||
class AClass(object):
    """Minimal callable class used as a pickling test target."""

    def __call__(self, x):
        """Return the argument unchanged."""
        return x
|
||||
|
||||
|
||||
def test_wrap_non_picklable_objects():
    """Smoke test: our wrapper behaves like the upstream loky one."""
    # Mostly a smoke test: test that we can use callable in the same way
    # with both our implementation of wrap_non_picklable_objects and the
    # upstream one
    for candidate in (a_function, AClass()):
        upstream_wrapped = wrap_non_picklable_objects(candidate)
        local_wrapped = _my_wrap_non_picklable_objects(candidate)
        assert upstream_wrapped(1) == local_wrapped(1)
|
||||
157
backend/venv/Lib/site-packages/joblib/test/test_config.py
Normal file
157
backend/venv/Lib/site-packages/joblib/test/test_config.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
|
||||
from joblib._parallel_backends import (
|
||||
LokyBackend,
|
||||
MultiprocessingBackend,
|
||||
ThreadingBackend,
|
||||
)
|
||||
from joblib.parallel import (
|
||||
BACKENDS,
|
||||
DEFAULT_BACKEND,
|
||||
EXTERNAL_BACKENDS,
|
||||
Parallel,
|
||||
delayed,
|
||||
parallel_backend,
|
||||
parallel_config,
|
||||
)
|
||||
from joblib.test.common import np, with_multiprocessing, with_numpy
|
||||
from joblib.test.test_parallel import check_memmap
|
||||
from joblib.testing import parametrize, raises
|
||||
|
||||
|
||||
@parametrize("context", [parallel_config, parallel_backend])
def test_global_parallel_backend(context):
    """Registering a global backend changes Parallel's default until unregister."""
    default = Parallel()._backend

    pb = context("threading")
    try:
        assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        # Always restore the process-global default so state does not leak
        # into other tests.
        pb.unregister()
    assert type(Parallel()._backend) is type(default)
|
||||
|
||||
|
||||
@parametrize("context", [parallel_config, parallel_backend])
def test_external_backends(context):
    """Backends listed in EXTERNAL_BACKENDS are lazily registered by name."""

    def register_foo():
        BACKENDS["foo"] = ThreadingBackend

    EXTERNAL_BACKENDS["foo"] = register_foo
    try:
        with context("foo"):
            assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        # Clean up the module-global registry.
        del EXTERNAL_BACKENDS["foo"]
|
||||
|
||||
|
||||
@with_numpy
@with_multiprocessing
def test_parallel_config_no_backend(tmpdir):
    """parallel_config with no backend still applies n_jobs/memmapping settings."""
    # Check that parallel_config allows to change the config
    # even if no backend is set.
    with parallel_config(n_jobs=2, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(prefer="processes") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2

            # Checks that memmapping is enabled
            p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
            assert len(os.listdir(tmpdir)) > 0
|
||||
|
||||
|
||||
@with_numpy
@with_multiprocessing
def test_parallel_config_params_explicit_set(tmpdir):
    """Explicit Parallel() arguments take precedence over parallel_config ones."""
    with parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(n_jobs=2, prefer="processes", max_nbytes="1M") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2

            # Checks that memmapping is disabled
            with raises(TypeError, match="Expected np.memmap instance"):
                p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
|
||||
|
||||
|
||||
@parametrize("param", ["prefer", "require"])
def test_parallel_config_bad_params(param):
    """An invalid backend hint or constraint value must raise ValueError."""
    bad_setting = {param: "wrong"}
    with raises(ValueError, match=f"{param}=wrong is not a valid"):
        with parallel_config(**bad_setting):
            Parallel()
|
||||
|
||||
|
||||
def test_parallel_config_constructor_params():
    """Backend constructor params require an explicit string backend."""
    # Check that an error is raised when backend is None
    # but backend constructor params are given
    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(inner_max_num_threads=1):
            pass

    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(backend_param=1):
            pass

    # A backend *instance* (rather than a name) cannot take extra params.
    with raises(ValueError, match="only supported when backend is a string"):
        with parallel_config(backend=BACKENDS[DEFAULT_BACKEND], backend_param=1):
            pass
|
||||
|
||||
|
||||
def test_parallel_config_nested():
    """Nested parallel_config blocks inherit settings from the parent block."""
    # Check that nested configuration retrieves the info from the
    # parent config and do not reset them.

    with parallel_config(n_jobs=2):
        p = Parallel()
        assert isinstance(p._backend, BACKENDS[DEFAULT_BACKEND])
        assert p.n_jobs == 2

    with parallel_config(backend="threading"):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert isinstance(p._backend, ThreadingBackend)
            assert p.n_jobs == 2

    with parallel_config(verbose=100):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert p.verbose == 100
            assert p.n_jobs == 2
|
||||
|
||||
|
||||
@with_numpy
@with_multiprocessing
@parametrize(
    "backend",
    ["multiprocessing", "threading", MultiprocessingBackend(), ThreadingBackend()],
)
@parametrize("context", [parallel_config, parallel_backend])
def test_threadpool_limitation_in_child_context_error(context, backend):
    """Backends that cannot limit child threads reject inner_max_num_threads."""
    with raises(AssertionError, match=r"does not acc.*inner_max_num_threads"):
        context(backend, inner_max_num_threads=1)
|
||||
|
||||
|
||||
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_n_jobs_none(context):
    """Parallel(n_jobs=None) means "unset" and falls back to the context value."""
    # Check that n_jobs=None is interpreted as "unset" in Parallel
    # non regression test for #1473
    with context(backend="threading", n_jobs=2):
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == 2

    with context(backend="threading"):
        # Without n_jobs in the context, None falls back to the backend default.
        default_n_jobs = Parallel().n_jobs
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == default_n_jobs
|
||||
|
||||
|
||||
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_config_n_jobs_none(context):
    """In parallel_config/backend, n_jobs=None counts as explicitly set."""
    # Check that n_jobs=None is interpreted as "explicitly set" in
    # parallel_(config/backend)
    # non regression test for #1473
    with context(backend="threading", n_jobs=2):
        with context(backend="threading", n_jobs=None):
            # n_jobs=None resets n_jobs to backend's default
            with Parallel() as p:
                assert p.n_jobs == 1
|
||||
607
backend/venv/Lib/site-packages/joblib/test/test_dask.py
Normal file
607
backend/venv/Lib/site-packages/joblib/test/test_dask.py
Normal file
@@ -0,0 +1,607 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from random import random
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import Parallel, delayed, parallel_backend, parallel_config
|
||||
from .._dask import DaskDistributedBackend
|
||||
from ..parallel import AutoBatchingMixin, ThreadingBackend
|
||||
from .common import np, with_numpy
|
||||
from .test_parallel import (
|
||||
_recursive_backend_info,
|
||||
_test_deadlock_with_generator,
|
||||
_test_parallel_unordered_generator_returns_fastest_first, # noqa: E501
|
||||
)
|
||||
|
||||
distributed = pytest.importorskip("distributed")
|
||||
dask = pytest.importorskip("dask")
|
||||
|
||||
# These imports need to be after the pytest.importorskip hence the noqa: E402
|
||||
from distributed import Client, LocalCluster, get_client # noqa: E402
|
||||
from distributed.metrics import time # noqa: E402
|
||||
|
||||
# Note: pytest requires to manually import all fixtures used in the test
|
||||
# and their dependencies.
|
||||
from distributed.utils_test import cleanup, cluster, inc # noqa: E402, F401
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    """Snapshot and restore thread-limit env vars around every test."""
    # when starting a dask nanny, the environment variable might change.
    # this fixture makes sure the environment is reset after the test.

    from joblib._parallel_backends import ParallelBackendBase

    old_value = {k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS}
    yield

    # Reset the environment variables to their original values
    for k, v in old_value.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v
|
||||
|
||||
|
||||
def noop(*args, **kwargs):
    """Accept any arguments and do nothing; used as a cheap parallel task."""
|
||||
|
||||
|
||||
def slow_raise_value_error(condition, duration=0.05):
    """Sleep ``duration`` seconds, then raise ValueError if ``condition``."""
    sleep(duration)
    if not condition:
        return
    raise ValueError("condition evaluated to True")
|
||||
|
||||
|
||||
def count_events(event_name, client):
    """Return a ``{worker: count}`` map of log entries named ``event_name``.

    Each worker's log is a sequence of tuples whose second element is the
    event name.
    """
    worker_logs = client.run(lambda dask_worker: dask_worker.log)
    return {
        worker: sum(1 for event in list(events) if event[1] == event_name)
        for worker, events in worker_logs.items()
    }
|
||||
|
||||
|
||||
def test_simple(loop):
    """Smoke test: joblib tasks run (and errors propagate) on a dask cluster."""
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                # A failing task must surface as an exception to the caller.
                with pytest.raises(ValueError):
                    Parallel()(
                        delayed(slow_raise_value_error)(i == 3) for i in range(10)
                    )

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]
|
||||
|
||||
|
||||
def test_dask_backend_uses_autobatching(loop):
    """The dask backend reuses AutoBatchingMixin's adaptive batch sizing."""
    assert (
        DaskDistributedBackend.compute_batch_size
        is AutoBatchingMixin.compute_batch_size
    )

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(delayed(lambda: None)() for _ in range(int(1e4)))
                    assert backend._effective_batch_size > 10
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    """generator_unordered must yield the fastest results first on dask too."""
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)
|
||||
|
||||
|
||||
@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
    """Generator-based result consumption must not deadlock on dask."""
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_deadlock_with_generator(None, return_as, n_jobs)
|
||||
|
||||
|
||||
@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
    """Nested Parallel calls keep using the dask backend at every level."""
    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        data = np.ones(int(1e7), dtype=np.uint8)
        for i in range(2):
            with context("dask"):
                backend_types_and_levels = _recursive_backend_info(data=data)
            assert len(backend_types_and_levels) == 4
            assert all(
                name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
            )

        # No argument
        with context("dask"):
            backend_types_and_levels = _recursive_backend_info()
        assert len(backend_types_and_levels) == 4
        assert all(
            name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
        )
|
||||
|
||||
|
||||
def random2():
    """Indirection over random() so dask cannot assume the task is pure."""
    return random()
|
||||
|
||||
|
||||
def test_dont_assume_function_purity(loop):
    """Two calls of the same function must not be collapsed into one task."""
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    """Batch tasks get a readable dask key name derived from the functions."""
    from joblib._dask import Batch

    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"
    else:
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"

    assert repr(Batch(tasks)) == batch_repr

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:
            with parallel_config(backend="dask"):
                _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

                def f(dask_scheduler):
                    return list(dask_scheduler.transition_log)

                # batch_size=2 splits the 4 tasks into batches of 2.
                batch_repr = batch_repr.replace("4", "2")
                log = client.run_on_scheduler(f)
                assert all("batch_of_inc" in tup[0] for tup in log)
|
||||
|
||||
|
||||
def test_no_undesired_distributed_cache_hit():
    """Non-regression: Batch must not bundle its arguments (joblib PR #1055)."""
    # Dask has a pickle cache for callables that are called many times. Because
    # the dask backends used to wrap both the functions and the arguments
    # under instances of the Batch callable class this caching mechanism could
    # lead to bugs as described in: https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as non-regression problem.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick-in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip("numpy")
    X = np.arange(int(1e6))

    def isolated_operation(list_, data=None):
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask"):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)

        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked it.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)
|
||||
|
||||
|
||||
class CountSerialized(object):
    """Wrap a value and count how many times this instance gets pickled."""

    def __init__(self, x):
        self.x = x
        # Incremented by __reduce__ each time *this* object is serialized.
        self.count = 0

    def __add__(self, other):
        # Works with both CountSerialized operands and plain numbers.
        return self.x + getattr(other, "x", other)

    __radd__ = __add__

    def __reduce__(self):
        # Record the pickling event; the reconstructed copy restarts at 0.
        self.count += 1
        return (CountSerialized, (self.x,))
|
||||
|
||||
|
||||
def add5(a, b, c, d=0, e=0):
    """Return the left-associated sum of three required and two optional args."""
    total = a + b
    total = total + c
    total = total + d
    return total + e
|
||||
|
||||
|
||||
def test_manual_scatter(loop):
    """joblib must not serialize scattered args more often than raw dask."""
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and equivalent native `client.submit` call.

    # Number of serializations can vary from dask to another, so this test only
    # checks that `joblib.Parallel` does not add more serialization steps than
    # a native `client.submit` call, but does not check for an exact number of
    # serialization steps.

    w, x, y, z = (CountSerialized(i) for i in range(4))

    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

                # Check that an error is raised for bad arguments, as scatter must
                # take a list/tuple
                with pytest.raises(TypeError):
                    with parallel_config(backend="dask", loop=loop, scatter=1):
                        pass

    # Scattered variables only serialized during scatter. Checking with an
    # extra variable as this count can vary from one dask version
    # to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization count
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native

    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Previous to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel
|
||||
|
||||
|
||||
# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    """Large arguments are auto-scattered once; small ones travel inline."""
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default large array are automatically scattered with
            # broadcast=1 which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    """Nested joblib-on-dask calls with sliced-array args must run cleanly."""
    np = pytest.importorskip("numpy")

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        # Inner level re-enters joblib from inside a dask worker task.
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )
|
||||
|
||||
|
||||
def test_nested_backend_context_manager(loop_in_thread):
    """Nested Parallel inside dask tasks reuses workers without deadlocking."""

    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2
|
||||
|
||||
|
||||
def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.

    def _backend_type(p):
        # Name of the backend class a Parallel instance resolved to.
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        # Executed on a worker: report the backend seen by a nested Parallel.
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    # n_jobs == -1 means "use all dask workers".
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1
|
||||
|
||||
|
||||
def test_errors(loop):
    """Using the dask backend with no dask client must raise a helpful error."""
    with pytest.raises(ValueError) as excinfo:
        with parallel_config(backend="dask"):
            pass

    message = str(excinfo.value).lower()
    assert "create a dask client" in message
|
||||
|
||||
|
||||
def test_correct_nested_backend(loop):
    """Nested Parallel calls must honor their own backend requirements."""
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # No requirement, should be us
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1)
                )
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Require threads, should be threading
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require="sharedmem") for _ in range(1)
                )
                assert isinstance(result[0][0][0], ThreadingBackend)
|
||||
|
||||
|
||||
def outer(nested_require):
    """Outermost level of the nested-backend helper chain (prefers threads)."""
    return Parallel(n_jobs=2, prefer="threads")(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    """Middle level: forwards the ``require`` constraint to a nested Parallel."""
    return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))


def inner():
    """Innermost level: report which backend an implicit Parallel resolves to."""
    return Parallel()._backend
|
||||
|
||||
|
||||
def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    # With an in-process (threads-only) cluster, asking for more jobs than
    # workers must not deadlock: tasks secede while waiting.
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))
|
||||
|
||||
|
||||
def _worker_address(_):
    """Return the address of the dask worker executing this task.

    The ignored argument lets the function be used with ``delayed`` inside a
    generator expression over a range.
    """
    from distributed import get_worker

    return get_worker().address
|
||||
|
||||
|
||||
def test_dask_backend_keywords(loop):
    """The ``workers`` keyword must pin all tasks to the given worker."""
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", workers=a["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [a["address"]] * 10

            with parallel_config(backend="dask", workers=b["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [b["address"]] * 10
|
||||
|
||||
|
||||
def test_scheduler_tasks_cleanup(loop):
    """After a Parallel run, no tasks or futures should linger on the scheduler."""
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        # Poll until the scheduler forgets the tasks, with a 5 s deadline.
        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5

        assert not client.futures
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    distributed.__version__ <= "2.1.1" and distributed.__version__ >= "1.28.0",
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    """Parallel must wait for workers when the cluster starts with none."""
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Tell the cluster to start workers but this is a non-blocking call
        # and new workers might take time to connect. In this case the Parallel
        # call should wait for at least one worker to come up before starting
        # to schedule work.
        cluster.scale(2)
    try:
        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        # Always tear down the client and cluster, even on failure.
        client.close()
        cluster.close()
|
||||
|
||||
|
||||
def test_wait_for_workers_timeout():
    """Check the error raised when no worker shows up before the timeout."""
    # Start a cluster with 0 worker:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: DaskDistributedBackend
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fallback to generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    """Process-based backends inside a daemonic dask worker must warn."""
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) make the test hang. Work-around:
            # return the warning record to the client and the warning check is
            # done client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))

            return record

        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        # Generous timeouts: worker processes may be slow to shut down.
        client.close(timeout=30)
        cluster.close(timeout=30)
|
||||
80
backend/venv/Lib/site-packages/joblib/test/test_disk.py
Normal file
80
backend/venv/Lib/site-packages/joblib/test/test_disk.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Unit tests for the disk utilities.
|
||||
"""
|
||||
|
||||
# Authors: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
|
||||
# Lars Buitinck
|
||||
# Copyright (c) 2010 Gael Varoquaux
|
||||
# License: BSD Style, 3 clauses.
|
||||
|
||||
from __future__ import with_statement
|
||||
|
||||
import array
|
||||
import os
|
||||
|
||||
from joblib.disk import disk_used, memstr_to_bytes, mkdirp, rm_subdirs
|
||||
from joblib.testing import parametrize, raises
|
||||
|
||||
###############################################################################
|
||||
|
||||
|
||||
def test_disk_used(tmpdir):
    """disk_used should report roughly the total size of files in a directory."""
    cachedir = tmpdir.strpath
    # Now write a file that is 1M big in this directory, and check the
    # size. The reason we use such a big file is that it makes us robust
    # to errors due to block allocation.
    # (Original comment said "Not write", which contradicted the code.)
    target_size = 1024  # expected size reported by disk_used, in kilobytes
    probe = array.array("i")
    sizeof_i = probe.itemsize
    n = int(target_size * 1024 / sizeof_i)
    payload = array.array("i", n * (1,))
    with open(os.path.join(cachedir, "test"), "wb") as output:
        payload.tofile(output)
    # Allow a small slack above the target for filesystem overhead.
    assert disk_used(cachedir) >= target_size
    assert disk_used(cachedir) < target_size + 12
|
||||
|
||||
|
||||
@parametrize(
    "text,value",
    [
        # Suffixes are binary multiples (K=2**10, M=2**20, G=2**30).
        ("80G", 80 * 1024**3),
        ("1.4M", int(1.4 * 1024**2)),
        ("120M", 120 * 1024**2),
        ("53K", 53 * 1024),
    ],
)
def test_memstr_to_bytes(text, value):
    """memstr_to_bytes must parse human-readable sizes into byte counts."""
    assert memstr_to_bytes(text) == value
|
||||
|
||||
|
||||
@parametrize(
    "text,exception,regex",
    [
        # Bad number or unknown unit suffix must raise ValueError.
        ("fooG", ValueError, r"Invalid literal for size.*fooG.*"),
        ("1.4N", ValueError, r"Invalid literal for size.*1.4N.*"),
    ],
)
def test_memstr_to_bytes_exception(text, exception, regex):
    """memstr_to_bytes must reject malformed size strings with a clear message."""
    with raises(exception) as excinfo:
        memstr_to_bytes(text)
    assert excinfo.match(regex)
|
||||
|
||||
|
||||
def test_mkdirp(tmpdir):
    """mkdirp must be idempotent, create parent dirs, and still raise on bad paths."""
    base = tmpdir.strpath
    mkdirp(os.path.join(base, "ham"))
    # Calling it again on an existing directory must not raise.
    mkdirp(os.path.join(base, "ham"))
    mkdirp(os.path.join(base, "spam", "spam"))

    # Not all OSErrors are ignored: an empty path is still an error.
    with raises(OSError):
        mkdirp("")
|
||||
|
||||
|
||||
def test_rm_subdirs(tmpdir):
    """rm_subdirs must empty a directory while keeping the directory itself."""
    parent = os.path.join(tmpdir.strpath, "subdir_one", "subdir_two")
    child = os.path.join(parent, "subdir_three")
    mkdirp(child)

    rm_subdirs(parent)
    assert os.path.exists(parent)
    assert not os.path.exists(child)
|
||||
338
backend/venv/Lib/site-packages/joblib/test/test_func_inspect.py
Normal file
338
backend/venv/Lib/site-packages/joblib/test/test_func_inspect.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Test the func_inspect module.
|
||||
"""
|
||||
|
||||
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
|
||||
# Copyright (c) 2009 Gael Varoquaux
|
||||
# License: BSD Style, 3 clauses.
|
||||
|
||||
import functools
|
||||
|
||||
from joblib.func_inspect import (
|
||||
_clean_win_chars,
|
||||
filter_args,
|
||||
format_signature,
|
||||
get_func_code,
|
||||
get_func_name,
|
||||
)
|
||||
from joblib.memory import Memory
|
||||
from joblib.test.common import with_numpy
|
||||
from joblib.testing import fixture, parametrize, raises
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Module-level functions and fixture, for tests
|
||||
# Signature fixtures for the introspection tests below: each function only
# exists for its argument list, none of them does anything.
def f(x, y=0):
    # One positional argument plus one defaulted keyword argument.
    pass


def g(x):
    # Single positional argument.
    pass


def h(x, y=0, *args, **kwargs):
    # Mixed signature: positional, defaulted, varargs and kwargs.
    pass


def i(x=1):
    # Single defaulted argument.
    pass


def j(x, y, **kwargs):
    # Two positionals plus arbitrary keyword arguments.
    pass


def k(*args, **kwargs):
    # Varargs/kwargs only.
    pass


def m1(x, *, y):
    # Keyword-only argument without a default.
    pass


def m2(x, *, y, z=3):
    # Keyword-only arguments, one of them defaulted.
    pass
|
||||
|
||||
|
||||
@fixture(scope="module")
|
||||
def cached_func(tmpdir_factory):
|
||||
# Create a Memory object to test decorated functions.
|
||||
# We should be careful not to call the decorated functions, so that
|
||||
# cache directories are not created in the temp dir.
|
||||
cachedir = tmpdir_factory.mktemp("joblib_test_func_inspect")
|
||||
mem = Memory(cachedir.strpath)
|
||||
|
||||
@mem.cache
|
||||
def cached_func_inner(x):
|
||||
return x
|
||||
|
||||
return cached_func_inner
|
||||
|
||||
|
||||
class Klass(object):
    """Minimal class used to test introspection of bound methods."""

    def f(self, x):
        # Identity method; the tests only inspect its signature.
        return x
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Tests
|
||||
|
||||
|
||||
@parametrize(
    "func,args,filtered_args",
    [
        # Each case is (function, [ignore_list, args, (kwargs)], expected).
        (f, [[], (1,)], {"x": 1, "y": 0}),
        (f, [["x"], (1,)], {"y": 0}),
        (f, [["y"], (0,)], {"x": 0}),
        (f, [["y"], (0,), {"y": 1}], {"x": 0}),
        (f, [["x", "y"], (0,)], {}),
        (f, [[], (0,), {"y": 1}], {"x": 0, "y": 1}),
        (f, [["y"], (), {"x": 2, "y": 1}], {"x": 2}),
        (g, [[], (), {"x": 1}], {"x": 1}),
        (i, [[], (2,)], {"x": 2}),
    ],
)
def test_filter_args(func, args, filtered_args):
    """filter_args must resolve defaults and drop ignored parameters."""
    assert filter_args(func, *args) == filtered_args
|
||||
|
||||
|
||||
def test_filter_args_method():
    """For a bound method, filter_args must include ``self`` in the result."""
    obj = Klass()
    assert filter_args(obj.f, [], (1,)) == {"x": 1, "self": obj}
|
||||
|
||||
|
||||
@parametrize(
    "func,args,filtered_args",
    [
        # Extra positionals land under "*", extra keywords under "**".
        (h, [[], (1,)], {"x": 1, "y": 0, "*": [], "**": {}}),
        (h, [[], (1, 2, 3, 4)], {"x": 1, "y": 2, "*": [3, 4], "**": {}}),
        (h, [[], (1, 25), {"ee": 2}], {"x": 1, "y": 25, "*": [], "**": {"ee": 2}}),
        (h, [["*"], (1, 2, 25), {"ee": 2}], {"x": 1, "y": 2, "**": {"ee": 2}}),
    ],
)
def test_filter_varargs(func, args, filtered_args):
    """filter_args must capture *args / **kwargs under "*" and "**" keys."""
    assert filter_args(func, *args) == filtered_args
|
||||
|
||||
|
||||
# Extra cases for keyword-only parameters, shared with test_filter_kwargs.
test_filter_kwargs_extra_params = [
    (m1, [[], (1,), {"y": 2}], {"x": 1, "y": 2}),
    (m2, [[], (1,), {"y": 2}], {"x": 1, "y": 2, "z": 3}),
]
|
||||
|
||||
|
||||
@parametrize(
    "func,args,filtered_args",
    [
        (k, [[], (1, 2), {"ee": 2}], {"*": [1, 2], "**": {"ee": 2}}),
        (k, [[], (3, 4)], {"*": [3, 4], "**": {}}),
    ]
    # Also exercise the keyword-only cases defined above.
    + test_filter_kwargs_extra_params,
)
def test_filter_kwargs(func, args, filtered_args):
    """filter_args must handle **kwargs-only and keyword-only signatures."""
    assert filter_args(func, *args) == filtered_args
|
||||
|
||||
|
||||
def test_filter_args_2():
    """filter_args corner cases: mixed kwargs and functools.partial objects."""
    assert filter_args(j, [], (1, 2), {"ee": 2}) == {"x": 1, "y": 2, "**": {"ee": 2}}

    ff = functools.partial(f, 1)
    # filter_args has to special-case partial
    assert filter_args(ff, [], (1,)) == {"*": [1], "**": {}}
    assert filter_args(ff, ["y"], (1,)) == {"*": [1], "**": {}}
|
||||
|
||||
|
||||
@parametrize("func,funcname", [(f, "f"), (g, "g"), (cached_func, "cached_func")])
|
||||
def test_func_name(func, funcname):
|
||||
# Check that we are not confused by decoration
|
||||
# here testcase 'cached_func' is the function itself
|
||||
assert get_func_name(func)[1] == funcname
|
||||
|
||||
|
||||
def test_func_name_on_inner_func(cached_func):
    # Check that we are not confused by decoration
    # here testcase 'cached_func' is the 'cached_func_inner' function
    # returned by 'cached_func' fixture
    assert get_func_name(cached_func)[1] == "cached_func_inner"
|
||||
|
||||
|
||||
def test_func_name_collision_on_inner_func():
    # Check that two functions defining and caching an inner function
    # with the same name do not cause (module, name) collision
    def f():
        def inner_func():
            return  # pragma: no cover

        return get_func_name(inner_func)

    def g():
        def inner_func():
            return  # pragma: no cover

        return get_func_name(inner_func)

    module, name = f()
    other_module, other_name = g()

    # Same inner name, but the resolved module path must differ.
    assert name == other_name
    assert module != other_module
|
||||
|
||||
|
||||
def test_func_inspect_errors():
    # Check that func_inspect is robust and will work on weird objects
    assert get_func_name("a".lower)[-1] == "lower"
    # Builtins have no retrievable source: (None, -1) is the fallback.
    assert get_func_code("a".lower)[1:] == (None, -1)
    ff = lambda x: x  # noqa: E731
    assert get_func_name(ff, win_characters=False)[-1] == "<lambda>"
    assert get_func_code(ff)[1] == __file__.replace(".pyc", ".py")
    # Simulate a function defined in __main__
    ff.__module__ = "__main__"
    assert get_func_name(ff, win_characters=False)[-1] == "<lambda>"
    assert get_func_code(ff)[1] == __file__.replace(".pyc", ".py")
|
||||
|
||||
|
||||
def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
    # Signature fixture: keyword-only parameters with defaults.
    pass


def func_with_signature(a: int, b: int) -> None:
    # Signature fixture: annotated parameters and return type.
    pass
|
||||
|
||||
|
||||
def test_filter_args_edge_cases():
    """filter_args behavior on keyword-only parameters."""
    assert filter_args(func_with_kwonly_args, [], (1, 2), {"kw1": 3, "kw2": 4}) == {
        "a": 1,
        "b": 2,
        "kw1": 3,
        "kw2": 4,
    }

    # Keyword-only parameters must be passed by keyword: trying to pass
    # 'kw1' positionally is rejected with an explicit error.
    # (An older comment here claimed the opposite; the assertion below is
    # what actually holds.)
    with raises(ValueError) as excinfo:
        filter_args(func_with_kwonly_args, [], (1, 2, 3), {"kw2": 2})
    excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")

    assert filter_args(
        func_with_kwonly_args, ["b", "kw2"], (1, 2), {"kw1": 3, "kw2": 4}
    ) == {"a": 1, "kw1": 3}

    # Annotations must not confuse the filtering.
    assert filter_args(func_with_signature, ["b"], (1, 2)) == {"a": 1}
|
||||
|
||||
|
||||
def test_bound_methods():
    """Make sure that calling the same method on two different instances
    of the same class does resolve to different signatures.
    """
    a = Klass()
    b = Klass()
    # ``self`` differs, so the filtered argument dicts must differ too.
    assert filter_args(a.f, [], (1,)) != filter_args(b.f, [], (1,))
|
||||
|
||||
|
||||
@parametrize(
    "exception,regex,func,args",
    [
        # Non-list ignore argument.
        (
            ValueError,
            "ignore_lst must be a list of parameters to ignore",
            f,
            ["bar", (None,)],
        ),
        # Ignoring a parameter the function does not have.
        (
            ValueError,
            r"Ignore list: argument \'(.*)\' is not defined",
            g,
            [["bar"], (None,)],
        ),
        # Too few positional arguments supplied.
        (ValueError, "Wrong number of arguments", h, [[]]),
    ],
)
def test_filter_args_error_msg(exception, regex, func, args):
    """Make sure that filter_args returns decent error messages, for the
    sake of the user.
    """
    with raises(exception) as excinfo:
        filter_args(func, *args)
    excinfo.match(regex)
|
||||
|
||||
|
||||
def test_filter_args_no_kwargs_mutation():
    """Non-regression test against 0.12.0 changes.

    https://github.com/joblib/joblib/pull/75

    Make sure filter_args doesn't mutate the kwargs dict that gets passed in.
    """
    kwargs = {"x": 0}
    filter_args(g, [], [], kwargs)
    assert kwargs == {"x": 0}
|
||||
|
||||
|
||||
def test_clean_win_chars():
    """_clean_win_chars must strip characters invalid in Windows filenames."""
    string = r"C:\foo\bar\main.py"
    mangled_string = _clean_win_chars(string)
    for char in ("\\", ":", "<", ">", "!"):
        assert char not in mangled_string
|
||||
|
||||
|
||||
@parametrize(
    "func,args,kwargs,sgn_expected",
    [
        (g, [list(range(5))], {}, "g([0, 1, 2, 3, 4])"),
        (k, [1, 2, (3, 4)], {"y": True}, "k(1, 2, (3, 4), y=True)"),
    ],
)
def test_format_signature(func, args, kwargs, sgn_expected):
    # Test signature formatting.
    path, sgn_result = format_signature(func, *args, **kwargs)
    assert sgn_result == sgn_expected
|
||||
|
||||
|
||||
def test_format_signature_long_arguments():
    """Very long argument reprs must be elided in formatted signatures."""
    shortening_threshold = 1500
    # shortening gets it down to 700 characters but there is the name
    # of the function in the signature and a few additional things
    # like dots for the ellipsis
    shortening_target = 700 + 10

    arg = "a" * shortening_threshold
    _, signature = format_signature(h, arg)
    assert len(signature) < shortening_target

    nb_args = 5
    args = [arg for _ in range(nb_args)]
    _, signature = format_signature(h, *args)
    assert len(signature) < shortening_target * nb_args

    kwargs = {str(i): arg for i, arg in enumerate(args)}
    _, signature = format_signature(h, **kwargs)
    assert len(signature) < shortening_target * nb_args

    # Positional and keyword arguments combined.
    _, signature = format_signature(h, *args, **kwargs)
    assert len(signature) < shortening_target * 2 * nb_args
|
||||
|
||||
|
||||
@with_numpy
def test_format_signature_numpy():
    """Test the format signature formatting with numpy."""
    # NOTE(review): the body is empty here — only the docstring remains.
    # Presumably placeholder or truncated; confirm against upstream joblib.
|
||||
|
||||
|
||||
def test_special_source_encoding():
    """get_func_code must handle sources with a non-UTF-8 coding cookie."""
    from joblib.test.test_func_inspect_special_encoding import big5_f

    func_code, source_file, first_line = get_func_code(big5_f)
    # These values are pinned to the exact layout of the fixture file.
    assert first_line == 5
    assert "def big5_f():" in func_code
    assert "test_func_inspect_special_encoding" in source_file
|
||||
|
||||
|
||||
def _get_code():
    """Helper run in worker processes: return big5_f's extracted source."""
    from joblib.test.test_func_inspect_special_encoding import big5_f

    return get_func_code(big5_f)[0]
|
||||
|
||||
|
||||
def test_func_code_consistency():
    """Source extraction must be identical across worker processes."""
    from joblib.parallel import Parallel, delayed

    codes = Parallel(n_jobs=2)(delayed(_get_code)() for _ in range(5))
    # All workers must agree on the extracted source.
    assert len(set(codes)) == 1
|
||||
@@ -0,0 +1,9 @@
|
||||
# -*- coding: big5 -*-
|
||||
|
||||
|
||||
# Some Traditional Chinese characters: 一些中文字符
|
||||
def big5_f():
|
||||
"""用於測試的函數
|
||||
"""
|
||||
# 註釋
|
||||
return 0
|
||||
520
backend/venv/Lib/site-packages/joblib/test/test_hashing.py
Normal file
520
backend/venv/Lib/site-packages/joblib/test/test_hashing.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""
|
||||
Test the hashing module.
|
||||
"""
|
||||
|
||||
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
|
||||
# Copyright (c) 2009 Gael Varoquaux
|
||||
# License: BSD Style, 3 clauses.
|
||||
|
||||
import collections
|
||||
import gc
|
||||
import hashlib
|
||||
import io
|
||||
import itertools
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from decimal import Decimal
|
||||
|
||||
from joblib.func_inspect import filter_args
|
||||
from joblib.hashing import hash
|
||||
from joblib.memory import Memory
|
||||
from joblib.test.common import np, with_numpy
|
||||
from joblib.testing import fixture, parametrize, raises, skipif
|
||||
|
||||
|
||||
def unicode(s):
    # Identity shim kept from the Python 2/3 compatibility era: the tests
    # below still call unicode(...) on literals; it now just returns s.
    return s
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Helper functions for the tests
|
||||
def time_func(func, *args):
    """Time function func on *args.

    Returns the best (minimum) of three runs, to reduce scheduling noise.
    """
    samples = []
    for _ in range(3):
        start = time.time()
        func(*args)
        samples.append(time.time() - start)
    return min(samples)
|
||||
|
||||
|
||||
def relative_time(func1, func2, *args):
    """Return the relative time between func1 and func2 applied on
    *args.

    The result is half the normalized absolute difference of the two
    timings, so it lies in [0, 0.5).
    """
    time_func1 = time_func(func1, *args)
    time_func2 = time_func(func2, *args)
    relative_diff = 0.5 * (abs(time_func1 - time_func2) / (time_func1 + time_func2))
    return relative_diff
|
||||
|
||||
|
||||
class Klass(object):
    """Minimal class used to test hashing of bound methods."""

    def f(self, x):
        # Identity method; only its hash matters to the tests.
        return x
|
||||
|
||||
|
||||
class KlassWithCachedMethod(object):
    """Class whose method is wrapped by Memory.cache, for hashing tests."""

    def __init__(self, cachedir):
        # Replace the plain method by its Memory-cached wrapper.
        mem = Memory(location=cachedir)
        self.f = mem.cache(self.f)

    def f(self, x):
        return x
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Tests
|
||||
|
||||
# Sample inputs covering the basic Python types, used pairwise by
# test_trivial_hash below: ints, floats, complex, strings, tuples, lists,
# dicts, None, callables, and unorderable sets/dict-keys.
input_list = [
    1,
    2,
    1.0,
    2.0,
    1 + 1j,
    2.0 + 1j,
    "a",
    "b",
    (1,),
    (
        1,
        1,
    ),
    [
        1,
    ],
    [
        1,
        1,
    ],
    {1: 1},
    {1: 2},
    {2: 1},
    None,
    gc.collect,
    [
        1,
    ].append,
    # Next 2 sets have unorderable elements in python 3.
    set(("a", 1)),
    set(("a", 1, ("a", 1))),
    # Next 2 dicts have unorderable type of keys in python 3.
    {"a": 1, 1: 2},
    {"a": 1, 1: 2, "d": {"a": 1}},
]
|
||||
|
||||
|
||||
@parametrize("obj1", input_list)
|
||||
@parametrize("obj2", input_list)
|
||||
def test_trivial_hash(obj1, obj2):
|
||||
"""Smoke test hash on various types."""
|
||||
# Check that 2 objects have the same hash only if they are the same.
|
||||
are_hashes_equal = hash(obj1) == hash(obj2)
|
||||
are_objs_identical = obj1 is obj2
|
||||
assert are_hashes_equal == are_objs_identical
|
||||
|
||||
|
||||
def test_hash_methods():
    # Check that hashing instance methods works
    a = io.StringIO(unicode("a"))
    # Same bound method of the same object: identical hashes.
    assert hash(a.flush) == hash(a.flush)
    # Same method of two objects with different state: different hashes.
    a1 = collections.deque(range(10))
    a2 = collections.deque(range(9))
    assert hash(a1.extend) != hash(a2.extend)
|
||||
|
||||
|
||||
@fixture(scope="function")
|
||||
@with_numpy
|
||||
def three_np_arrays():
|
||||
rnd = np.random.RandomState(0)
|
||||
arr1 = rnd.random_sample((10, 10))
|
||||
arr2 = arr1.copy()
|
||||
arr3 = arr2.copy()
|
||||
arr3[0] += 1
|
||||
return arr1, arr2, arr3
|
||||
|
||||
|
||||
def test_hash_numpy_arrays(three_np_arrays):
    """Array hashes must agree exactly with element-wise array equality."""
    arr1, arr2, arr3 = three_np_arrays

    for obj1, obj2 in itertools.product(three_np_arrays, repeat=2):
        are_hashes_equal = hash(obj1) == hash(obj2)
        are_arrays_equal = np.all(obj1 == obj2)
        assert are_hashes_equal == are_arrays_equal

    # A transposed view must hash differently from the original layout.
    assert hash(arr1) != hash(arr1.T)
|
||||
|
||||
|
||||
def test_hash_numpy_dict_of_arrays(three_np_arrays):
    """Dicts of arrays hash by content: equal arrays interchangeable."""
    arr1, arr2, arr3 = three_np_arrays

    d1 = {1: arr1, 2: arr2}
    d2 = {1: arr2, 2: arr1}
    d3 = {1: arr2, 2: arr3}

    # arr1 == arr2, so swapping them does not change the hash.
    assert hash(d1) == hash(d2)
    assert hash(d1) != hash(d3)
|
||||
|
||||
|
||||
@with_numpy
@parametrize("dtype", ["datetime64[s]", "timedelta64[D]"])
def test_numpy_datetime_array(dtype):
    # memoryview is not supported for some dtypes e.g. datetime64
    # see https://github.com/joblib/joblib/issues/188 for more details
    a_hash = hash(np.arange(10))
    array = np.arange(0, 10, dtype=dtype)
    assert hash(array) != a_hash
|
||||
|
||||
|
||||
@with_numpy
def test_hash_numpy_noncontiguous():
    """Memory layout is part of the hash: views vs. contiguous copies differ."""
    a = np.asarray(np.arange(6000).reshape((1000, 2, 3)), order="F")[:, :1, :]
    b = np.ascontiguousarray(a)
    assert hash(a) != hash(b)

    c = np.asfortranarray(a)
    assert hash(a) != hash(c)
|
||||
|
||||
|
||||
@with_numpy
@parametrize("coerce_mmap", [True, False])
def test_hash_memmap(tmpdir, coerce_mmap):
    """Check that memmap and arrays hash identically if coerce_mmap is True."""
    filename = tmpdir.join("memmap_temp").strpath
    try:
        m = np.memmap(filename, shape=(10, 10), mode="w+")
        a = np.asarray(m)
        are_hashes_equal = hash(a, coerce_mmap=coerce_mmap) == hash(
            m, coerce_mmap=coerce_mmap
        )
        # Hashes match exactly when coercion is requested.
        assert are_hashes_equal == coerce_mmap
    finally:
        if "m" in locals():
            del m
            # Force a garbage-collection cycle, to be certain that the
            # object is delete, and we don't run in a problem under
            # Windows with a file handle still open.
            gc.collect()
|
||||
|
||||
|
||||
@with_numpy
@skipif(
    sys.platform == "win32",
    reason="This test is not stable under windows for some reason",
)
def test_hash_numpy_performance():
    """Check the performance of hashing numpy arrays:

    In [22]: a = np.random.random(1000000)

    In [23]: %timeit hashlib.md5(a).hexdigest()
    100 loops, best of 3: 20.7 ms per loop

    In [24]: %timeit hashlib.md5(pickle.dumps(a, protocol=2)).hexdigest()
    1 loops, best of 3: 73.1 ms per loop

    In [25]: %timeit hashlib.md5(cPickle.dumps(a, protocol=2)).hexdigest()
    10 loops, best of 3: 53.9 ms per loop

    In [26]: %timeit hash(a)
    100 loops, best of 3: 20.8 ms per loop
    """
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(1000000)

    def md5_hash(x):
        # Baseline: raw md5 over the array's buffer.
        return hashlib.md5(memoryview(x)).hexdigest()

    # joblib hash must stay within 30% (relative) of the raw md5 baseline.
    relative_diff = relative_time(md5_hash, hash, a)
    assert relative_diff < 0.3

    # Check that hashing an tuple of 3 arrays takes approximately
    # 3 times as much as hashing one array
    time_hashlib = 3 * time_func(md5_hash, a)
    time_hash = time_func(hash, (a, a, a))
    relative_diff = 0.5 * (abs(time_hash - time_hashlib) / (time_hash + time_hashlib))
    assert relative_diff < 0.3
|
||||
|
||||
|
||||
def test_bound_methods_hash():
    """Make sure that calling the same method on two different instances
    of the same class does resolve to the same hashes.
    """
    a = Klass()
    b = Klass()
    assert hash(filter_args(a.f, [], (1,))) == hash(filter_args(b.f, [], (1,)))
|
||||
|
||||
|
||||
def test_bound_cached_methods_hash(tmpdir):
    """Make sure that calling the same _cached_ method on two different
    instances of the same class does resolve to the same hashes.
    """
    a = KlassWithCachedMethod(tmpdir.strpath)
    b = KlassWithCachedMethod(tmpdir.strpath)
    # ``.func`` unwraps the Memory-cached callable before filtering args.
    assert hash(filter_args(a.f.func, [], (1,))) == hash(
        filter_args(b.f.func, [], (1,))
    )
|
||||
|
||||
|
||||
@with_numpy
def test_hash_object_dtype():
    """Make sure that ndarrays with dtype `object' hash correctly."""

    a = np.array([np.arange(i) for i in range(6)], dtype=object)
    b = np.array([np.arange(i) for i in range(6)], dtype=object)

    # Equal contents, distinct objects: hashes must still match.
    assert hash(a) == hash(b)
|
||||
|
||||
|
||||
@with_numpy
def test_numpy_scalar():
    # Numpy scalars are built from compiled functions, and lead to
    # strange pickling paths explored, that can give hash collisions
    a = np.float64(2.0)
    b = np.float64(3.0)
    assert hash(a) != hash(b)
|
||||
|
||||
|
||||
def test_dict_hash(tmpdir):
    # Check that dictionaries hash consistently, even though the ordering
    # of the keys is not guaranteed
    k = KlassWithCachedMethod(tmpdir.strpath)

    d = {
        "#s12069__c_maps.nii.gz": [33],
        "#s12158__c_maps.nii.gz": [33],
        "#s12258__c_maps.nii.gz": [33],
        "#s12277__c_maps.nii.gz": [33],
        "#s12300__c_maps.nii.gz": [33],
        "#s12401__c_maps.nii.gz": [33],
        "#s12430__c_maps.nii.gz": [33],
        "#s13817__c_maps.nii.gz": [33],
        "#s13903__c_maps.nii.gz": [33],
        "#s13916__c_maps.nii.gz": [33],
        "#s13981__c_maps.nii.gz": [33],
        "#s13982__c_maps.nii.gz": [33],
        "#s13983__c_maps.nii.gz": [33],
    }

    # Round-trip through the cached method, then compare hashes.
    a = k.f(d)
    b = k.f(a)

    assert hash(a) == hash(b)
|
||||
|
||||
|
||||
def test_set_hash(tmpdir):
    # Check that sets hash consistently, even though their ordering
    # is not guaranteed
    k = KlassWithCachedMethod(tmpdir.strpath)

    s = set(
        [
            "#s12069__c_maps.nii.gz",
            "#s12158__c_maps.nii.gz",
            "#s12258__c_maps.nii.gz",
            "#s12277__c_maps.nii.gz",
            "#s12300__c_maps.nii.gz",
            "#s12401__c_maps.nii.gz",
            "#s12430__c_maps.nii.gz",
            "#s13817__c_maps.nii.gz",
            "#s13903__c_maps.nii.gz",
            "#s13916__c_maps.nii.gz",
            "#s13981__c_maps.nii.gz",
            "#s13982__c_maps.nii.gz",
            "#s13983__c_maps.nii.gz",
        ]
    )

    # Round-trip through the cached method, then compare hashes.
    a = k.f(s)
    b = k.f(a)

    assert hash(a) == hash(b)
|
||||
|
||||
|
||||
def test_set_decimal_hash():
    # Sets containing Decimals (including NaN, which is unorderable) must
    # hash the same regardless of construction order.
    forward = set([Decimal(0), Decimal("NaN")])
    backward = set([Decimal("NaN"), Decimal(0)])
    assert hash(forward) == hash(backward)
|
||||
|
||||
|
||||
def test_string():
    # The hash of a container must not depend on the provenance of the
    # (immutable) strings it holds — interned, copied or unpickled.
    string = "foo"
    first = {string: "bar"}
    second = {string: "bar"}
    roundtripped = pickle.loads(pickle.dumps(second))
    assert hash([first, second]) == hash([first, roundtripped])
|
||||
|
||||
|
||||
@with_numpy
def test_numpy_dtype_pickling():
    # numpy dtype hashing is tricky to get right: see #231, #239, #251 #1080,
    # #1082, and explanatory comments inside
    # ``joblib.hashing.NumpyHasher.save``.
    # Here we verify that dtype hashing is robust to object identity and to
    # a pickle round-trip.

    # Simple dtypes are interned by numpy: same object, same hash.
    simple_a = np.dtype("f4")
    simple_b = np.dtype("f4")
    assert simple_a is simple_b
    assert hash(simple_a) == hash(simple_b)

    # A pickle round-trip yields a distinct object with an equal hash.
    simple_copy = pickle.loads(pickle.dumps(simple_a))
    assert simple_a is not simple_copy
    assert hash(simple_a) == hash(simple_copy)

    assert hash([simple_a, simple_a]) == hash([simple_copy, simple_copy])
    assert hash([simple_a, simple_a]) == hash([simple_a, simple_copy])

    # Structured dtypes are NOT interned, yet must still hash equal.
    fields = [("name", np.str_, 16), ("grades", np.float64, (2,))]
    struct_a = np.dtype(fields)
    struct_b = np.dtype(fields)
    assert hash(struct_a) == hash(struct_b)

    struct_copy = pickle.loads(pickle.dumps(struct_a))
    assert struct_copy is not struct_a
    assert hash(struct_a) == hash(struct_copy)

    assert hash([struct_a, struct_a]) == hash([struct_copy, struct_copy])
    assert hash([struct_a, struct_a]) == hash([struct_copy, struct_a])
|
||||
|
||||
|
||||
@parametrize(
    "to_hash,expected",
    [
        ("This is a string to hash", "71b3f47df22cb19431d85d92d0b230b2"),
        ("C'est l\xe9t\xe9", "2d8d189e9b2b0b2e384d93c868c0e576"),
        ((123456, 54321, -98765), "e205227dd82250871fa25aa0ec690aa3"),
        (
            [random.Random(42).random() for _ in range(5)],
            "a11ffad81f9682a7d901e6edc3d16c84",
        ),
        ({"abcde": 123, "sadfas": [-9999, 2, 3]}, "aeda150553d4bb5c69f0e69d51b0e2ef"),
    ],
)
def test_hashes_stay_the_same(to_hash, expected):
    """Pin joblib hash digests for a few pure-Python objects."""
    # We want to make sure that hashes don't change with joblib
    # version. For end users, that would mean that they have to
    # regenerate their cache from scratch, which potentially means
    # lengthy recomputations.
    # Expected results have been generated with joblib 0.9.2
    assert hash(to_hash) == expected
|
||||
|
||||
|
||||
@with_numpy
def test_hashes_are_different_between_c_and_fortran_contiguous_arrays():
    """A C-contiguous array and its Fortran-contiguous copy must produce
    two distinct hashes."""
    random_state = np.random.RandomState(0)
    c_contiguous = random_state.random_sample((10, 10))
    f_contiguous = np.asfortranarray(c_contiguous)
    assert hash(c_contiguous) != hash(f_contiguous)
|
||||
|
||||
|
||||
@with_numpy
def test_0d_array():
    # Smoke test: hashing a zero-dimensional array must not raise.
    zero_dim = np.array(0)
    hash(zero_dim)
|
||||
|
||||
|
||||
@with_numpy
def test_0d_and_1d_array_hashing_is_different():
    # A scalar array and a length-1 vector hold the same value but differ
    # in shape, so their hashes must differ.
    scalar_arr = np.array(0)
    vector_arr = np.array([0])
    assert hash(scalar_arr) != hash(vector_arr)
|
||||
|
||||
|
||||
@with_numpy
def test_hashes_stay_the_same_with_numpy_objects():
    # Note: joblib used to test numpy objects hashing by comparing the produced
    # hash of an object with some hard-coded target value to guarantee that
    # hashing remains the same across joblib versions. However, since numpy
    # 1.20 and joblib 1.0, joblib relies on potentially unstable implementation
    # details of numpy to hash np.dtype objects, which makes the stability of
    # hash values across different environments hard to guarantee and to test.
    # As a result, hashing stability across joblib versions becomes best-effort
    # only, and we only test the consistency within a single environment by
    # making sure:
    # - the hash of two copies of the same objects is the same
    # - hashing some object in two different python processes produces the same
    #   value. This should be viewed as a proxy for testing hash consistency
    #   through time between Python sessions (provided no change in the
    #   environment was done between sessions).

    def create_objects_to_hash():
        # A fresh seeded RNG per call, so each invocation yields
        # bit-identical copies of the same objects.
        rng = np.random.RandomState(42)
        # Being explicit about dtypes in order to avoid
        # architecture-related differences. Also using 'f4' rather than
        # 'f8' for float arrays because 'f8' arrays generated by
        # rng.random.randn don't seem to be bit-identical on 32bit and
        # 64bit machines.
        to_hash_list = [
            rng.randint(-1000, high=1000, size=50).astype("<i8"),
            tuple(rng.randn(3).astype("<f4") for _ in range(5)),
            [rng.randn(3).astype("<f4") for _ in range(5)],
            {
                -3333: rng.randn(3, 5).astype("<f4"),
                0: [
                    rng.randint(10, size=20).astype("<i8"),
                    rng.randn(10).astype("<f4"),
                ],
            },
            # Non regression cases for
            # https://github.com/joblib/joblib/issues/308
            np.arange(100, dtype="<i8").reshape((10, 10)),
            # Fortran contiguous array
            np.asfortranarray(np.arange(100, dtype="<i8").reshape((10, 10))),
            # Non contiguous array
            np.arange(100, dtype="<i8").reshape((10, 10))[:, :2],
        ]
        return to_hash_list

    # Create two lists containing copies of the same objects. joblib.hash
    # should return the same hash for to_hash_list_one[i] and
    # to_hash_list_two[i]
    to_hash_list_one = create_objects_to_hash()
    to_hash_list_two = create_objects_to_hash()

    # Two independent single-worker pools: hashing the same object in two
    # distinct processes must agree.
    e1 = ProcessPoolExecutor(max_workers=1)
    e2 = ProcessPoolExecutor(max_workers=1)

    try:
        for obj_1, obj_2 in zip(to_hash_list_one, to_hash_list_two):
            # testing consistency of hashes across python processes
            hash_1 = e1.submit(hash, obj_1).result()
            hash_2 = e2.submit(hash, obj_1).result()
            assert hash_1 == hash_2

            # testing consistency when hashing two copies of the same objects.
            hash_3 = e1.submit(hash, obj_2).result()
            assert hash_1 == hash_3

    finally:
        # Always release the worker processes, even on assertion failure.
        e1.shutdown()
        e2.shutdown()
|
||||
|
||||
|
||||
def test_hashing_pickling_error():
    """Hashing an unpicklable object (here, a function defined in a local
    scope) must raise a PicklingError mentioning hashing."""

    def local_function():
        return 42

    with raises(pickle.PicklingError) as excinfo:
        hash(local_function)
    excinfo.match("PicklingError while hashing")
|
||||
|
||||
|
||||
def test_wrong_hash_name():
    """An unsupported hash_name must raise a ValueError that lists the
    valid options."""
    expected_msg = "Valid options for 'hash_name' are"
    payload = {"foo": "bar"}
    with raises(ValueError, match=expected_msg):
        hash(payload, hash_name="invalid")
|
||||
15
backend/venv/Lib/site-packages/joblib/test/test_init.py
Normal file
15
backend/venv/Lib/site-packages/joblib/test/test_init.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Basic test case to test functioning of module's top-level
|
||||
|
||||
# The wildcard import is attempted at module level because "import *" is
# only legal there; the outcome is recorded in _top_import_error and
# asserted on inside test_import_joblib below.
try:
    from joblib import *  # noqa

    _top_import_error = None
except Exception as ex:  # pragma: no cover
    _top_import_error = ex
|
||||
|
||||
|
||||
def test_import_joblib():
    """Fail if the module-level ``from joblib import *`` above raised.

    The wildcard import must run at module scope, so its outcome is
    captured in ``_top_import_error`` and only checked here.
    """
    assert _top_import_error is None
|
||||
29
backend/venv/Lib/site-packages/joblib/test/test_logger.py
Normal file
29
backend/venv/Lib/site-packages/joblib/test/test_logger.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Test the logger module.
|
||||
"""
|
||||
|
||||
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
|
||||
# Copyright (c) 2009 Gael Varoquaux
|
||||
# License: BSD Style, 3 clauses.
|
||||
import re
|
||||
|
||||
from joblib.logger import PrintTime
|
||||
|
||||
|
||||
def test_print_time(tmpdir, capsys):
    """Smoke test for PrintTime, including its log-rotation behavior.

    Three PrintTime instances are created on the same logfile so that the
    rotation path is exercised; stderr output is then checked with a regex
    that tolerates timing variations.
    """
    logfile = tmpdir.join("test.log").strpath
    print_time = PrintTime(logfile=logfile)
    print_time("Foo")
    # Create a second time, to smoke test log rotation.
    print_time = PrintTime(logfile=logfile)
    print_time("Foo")
    # And a third time
    print_time = PrintTime(logfile=logfile)
    print_time("Foo")

    out_printed_text, err_printed_text = capsys.readouterr()
    # Use regexps to be robust to time variations
    match = r"Foo: 0\..s, 0\..min\nFoo: 0\..s, 0..min\nFoo: " + r".\..s, 0..min\n"
    if not re.match(match, err_printed_text):
        # Bug fix: the message used to read "Excepted %s, got %s".
        raise AssertionError("Expected %s, got %s" % (match, err_printed_text))
|
||||
1280
backend/venv/Lib/site-packages/joblib/test/test_memmapping.py
Normal file
1280
backend/venv/Lib/site-packages/joblib/test/test_memmapping.py
Normal file
File diff suppressed because it is too large
Load Diff
1582
backend/venv/Lib/site-packages/joblib/test/test_memory.py
Normal file
1582
backend/venv/Lib/site-packages/joblib/test/test_memory.py
Normal file
File diff suppressed because it is too large
Load Diff
180
backend/venv/Lib/site-packages/joblib/test/test_memory_async.py
Normal file
180
backend/venv/Lib/site-packages/joblib/test/test_memory_async.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from joblib.memory import (
|
||||
AsyncMemorizedFunc,
|
||||
AsyncNotMemorizedFunc,
|
||||
MemorizedResult,
|
||||
Memory,
|
||||
NotMemorizedResult,
|
||||
)
|
||||
from joblib.test.common import np, with_numpy
|
||||
from joblib.testing import raises
|
||||
|
||||
from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn
|
||||
|
||||
|
||||
async def check_identity_lazy_async(func, accumulator, location):
    """Similar to check_identity_lazy but for coroutine functions.

    Caches ``func`` with a Memory rooted at ``location`` and checks that,
    for each input ``i``, the wrapped coroutine returns ``i`` and is
    actually executed only once across two awaits (``accumulator`` grows
    by exactly one entry per distinct input).
    """
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
    for i in range(3):
        for _ in range(2):
            value = await func(i)
            assert value == i
            # Second await of the same input must be a cache hit.
            assert len(accumulator) == i + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """End-to-end Memory test with a coroutine: lazy caching, cache
    clearing across compress/mmap_mode combinations, Memory.eval, and
    __main__ name mangling."""
    accumulator = list()

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)  # side effect counts actual executions
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First clear the cache directory, to check that our code can
            # handle that
            # NOTE: this line would raise an exception, as the database
            # file is still open; we ignore the error since we want to
            # test what happens if the directory disappears
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)

            # Cleared cache => exactly one recomputation.
            assert len(accumulator) == current_accumulator + 1
            # Also, check that Memory.eval works similarly
            evaled = await memory.eval(f, 1)
            assert evaled == out
            assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = "__main__"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_memory_async():
    """With location=None, Memory.cache must be a no-op: the wrapped
    coroutine runs on every call instead of being memoized."""
    calls = list()

    async def ff(x):
        await asyncio.sleep(0.1)
        calls.append(1)
        return x

    memory = Memory(location=None, verbose=0)
    cached_ff = memory.cache(ff)
    for _ in range(4):
        n_calls_before = len(calls)
        await cached_ff(1)
        # Every await must trigger a real execution.
        assert len(calls) == n_calls_before + 1
|
||||
|
||||
|
||||
@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""

    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    a = np.ones(3)
    b = await twice(a)  # first call: computes and stores the result
    c = await twice(a)  # second call: served from the cache

    assert isinstance(c, np.memmap)
    assert c.mode == "r"

    # Even the very first call must hand back a read-only memmap.
    assert isinstance(b, np.memmap)
    assert b.mode == "r"

    # Corrupts the file, Deleting b and c mmaps
    # is necessary to be able edit the file
    del b
    del c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = await twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    # Asserts that the recomputation returns a mmap
    assert isinstance(d, np.memmap)
    assert d.mode == "r"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """Each async cache flavour must return the matching shelved-result
    type from call_and_shelve, round-trip the value through get(), and
    tolerate clear() on an already-cleared result."""

    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x**2 + y

    # Pair each async cache wrapper with the result type it must produce.
    cases = [
        (AsyncMemorizedFunc(f, tmpdir.strpath), MemorizedResult),
        (AsyncNotMemorizedFunc(f), NotMemorizedResult),
        (Memory(location=tmpdir.strpath, verbose=0).cache(f), MemorizedResult),
        (Memory(location=None).cache(f), NotMemorizedResult),
    ]
    for func, Result in cases:
        for _ in range(2):
            result = await func.call_and_shelve(2)
            assert isinstance(result, Result)
            assert result.get() == 5

            result.clear()
            with raises(KeyError):
                result.get()
            result.clear()  # Do nothing if there is no cache.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """`.call` must bypass the cache for an async memoized function and
    return both the value and a metadata dict."""

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    cached = memory.cache(ff, ignore=["counter"])

    call_counts = {}
    assert await cached(2, call_counts) == 1
    # Second call is a cache hit: ff is not re-executed.
    assert await cached(2, call_counts) == 1

    # `.call` forces execution, incrementing the counter to 2.
    value, metadata = await cached.call(2, call_counts)
    assert value == 2, "f has not been called properly"
    assert isinstance(metadata, dict), "Metadata are not returned by MemorizedFunc.call."
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Pyodide and other single-threaded Python builds will be missing the
|
||||
_multiprocessing module. Test that joblib still works in this environment.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_missing_multiprocessing(tmp_path):
    """
    Test that import joblib works even if _multiprocessing is missing.

    pytest has already imported everything from joblib. The most reasonable way
    to test importing joblib with modified environment is to invoke a separate
    Python process. This also ensures that we don't break other tests by
    importing a bad `_multiprocessing` module.
    """
    (tmp_path / "_multiprocessing.py").write_text(
        'raise ImportError("No _multiprocessing module!")'
    )
    env = dict(os.environ)
    # For subprocess, use current sys.path with our custom version of
    # multiprocessing inserted. Join with os.pathsep (":" on POSIX, ";" on
    # Windows): a hard-coded ":" yields a PYTHONPATH the child interpreter
    # cannot parse on Windows.
    env["PYTHONPATH"] = os.pathsep.join([str(tmp_path)] + sys.path)
    subprocess.check_call(
        [
            sys.executable,
            "-c",
            "import joblib, math; "
            "joblib.Parallel(n_jobs=1)("
            "joblib.delayed(math.sqrt)(i**2) for i in range(10))",
        ],
        env=env,
    )
|
||||
55
backend/venv/Lib/site-packages/joblib/test/test_module.py
Normal file
55
backend/venv/Lib/site-packages/joblib/test/test_module.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import sys
|
||||
|
||||
import joblib
|
||||
from joblib.test.common import with_multiprocessing
|
||||
from joblib.testing import check_subprocess_call
|
||||
|
||||
|
||||
def test_version():
    # The joblib package must always expose a version string.
    has_version = hasattr(joblib, "__version__")
    assert has_version, "There are no __version__ argument on the joblib module"
|
||||
|
||||
|
||||
@with_multiprocessing
def test_no_start_method_side_effect_on_import():
    # check that importing joblib does not implicitly set the global
    # start_method for multiprocessing.
    # The child script calls mp.set_start_method, which raises RuntimeError
    # if a start method was already fixed as a side effect of the import.
    # NOTE(review): "loky" is not a stdlib start method name — presumably
    # joblib registers it on import; confirm against joblib's backend setup.
    code = """if True:
        import joblib
        import multiprocessing as mp
        # The following line would raise RuntimeError if the
        # start_method is already set.
        mp.set_start_method("loky")
    """
    check_subprocess_call([sys.executable, "-c", code])
|
||||
|
||||
|
||||
@with_multiprocessing
def test_no_semaphore_tracker_on_import():
    # check that importing joblib does not implicitly spawn a resource tracker
    # or a semaphore tracker.
    # The tracker was called "semaphore_tracker" before Python 3.8 and
    # "resource_tracker" afterwards; joblib no longer supports Python < 3.8,
    # so the old `sys.version_info >= (3, 8)` rename shim was dead code and
    # the modern name is used directly.
    code = """if True:
        import joblib
        from multiprocessing import resource_tracker
        # The following line would raise RuntimeError if the
        # start_method is already set.
        msg = "multiprocessing.resource_tracker has been spawned on import"
        assert resource_tracker._resource_tracker._fd is None, msg"""
    check_subprocess_call([sys.executable, "-c", code])
|
||||
|
||||
|
||||
@with_multiprocessing
def test_no_resource_tracker_on_import():
    # Check that importing joblib does not implicitly spawn loky's own
    # resource tracker: its file descriptor must still be unset in a fresh
    # subprocess that only imports joblib.
    code = """if True:
        import joblib
        from joblib.externals.loky.backend import resource_tracker
        # The following line would raise RuntimeError if the
        # start_method is already set.
        msg = "loky.resource_tracker has been spawned on import"
        assert resource_tracker._resource_tracker._fd is None, msg
    """
    check_subprocess_call([sys.executable, "-c", code])
|
||||
1225
backend/venv/Lib/site-packages/joblib/test/test_numpy_pickle.py
Normal file
1225
backend/venv/Lib/site-packages/joblib/test/test_numpy_pickle.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
"""Test the old numpy pickler, compatibility version."""
|
||||
|
||||
# numpy_pickle is not a drop-in replacement of pickle, as it takes
|
||||
# filenames instead of open files as arguments.
|
||||
from joblib import numpy_pickle_compat
|
||||
|
||||
|
||||
def test_z_file(tmpdir):
    """Round-trip a byte payload through write_zfile/read_zfile."""
    path = tmpdir.join("test.pkl").strpath
    payload = numpy_pickle_compat.asbytes("Foo, \n Bar, baz, \n\nfoobar")
    with open(path, "wb") as out_file:
        numpy_pickle_compat.write_zfile(out_file, payload)
    with open(path, "rb") as in_file:
        reloaded = numpy_pickle_compat.read_zfile(in_file)
    assert payload == reloaded
|
||||
@@ -0,0 +1,9 @@
|
||||
from joblib.compressor import BinaryZlibFile
|
||||
from joblib.testing import parametrize
|
||||
|
||||
|
||||
@parametrize("filename", ["test", "test"])  # testing str and unicode names
def test_binary_zlib_file(tmpdir, filename):
    """Smoke test: a BinaryZlibFile can be created and closed for the
    given filename.

    NOTE(review): both parametrize values are plain ``str`` on Python 3;
    the str/unicode split looks like a Python 2 leftover — confirm.
    """
    target_path = tmpdir.join(filename).strpath
    zlib_file = BinaryZlibFile(target_path, mode="wb")
    zlib_file.close()
||||
2250
backend/venv/Lib/site-packages/joblib/test/test_parallel.py
Normal file
2250
backend/venv/Lib/site-packages/joblib/test/test_parallel.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,94 @@
|
||||
try:
|
||||
# Python 2.7: use the C pickle to speed up
|
||||
# test_concurrency_safe_write which pickles big python objects
|
||||
import cPickle as cpickle
|
||||
except ImportError:
|
||||
import pickle as cpickle
|
||||
import functools
|
||||
import time
|
||||
from pickle import PicklingError
|
||||
|
||||
import pytest
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._store_backends import (
|
||||
CacheWarning,
|
||||
FileSystemStoreBackend,
|
||||
concurrency_safe_write,
|
||||
)
|
||||
from joblib.backports import concurrency_safe_rename
|
||||
from joblib.test.common import with_multiprocessing
|
||||
from joblib.testing import parametrize, timeout
|
||||
|
||||
|
||||
def write_func(output, filename):
    """Serialize ``output`` to ``filename`` with pickle."""
    with open(filename, "wb") as handle:
        cpickle.dump(output, handle)
|
||||
|
||||
|
||||
def load_func(expected, filename):
    """Read ``filename`` back with pickle and assert it equals ``expected``.

    Retries for up to ~1 second because on Windows a concurrent writer can
    hold a lock on the file while we try to read it.
    """
    reloaded = None
    last_error = None
    for i in range(10):
        try:
            with open(filename, "rb") as f:
                reloaded = cpickle.load(f)
            break
        except OSError as e:
            # On Windows you can have WindowsError ([Error 5] Access
            # is denied or [Error 13] Permission denied) when reading the file,
            # probably because a writer process has a lock on the file
            last_error = e
            time.sleep(0.1)
    else:
        # Bug fix: the original did a bare `raise` here, but the except
        # clause has already exited by the time the for/else runs, so there
        # is no active exception and Python raises
        # "RuntimeError: No active exception to re-raise", masking the real
        # error. Re-raise the last recorded OSError explicitly instead.
        raise last_error
    assert expected == reloaded
|
||||
|
||||
|
||||
def concurrency_safe_write_rename(to_write, filename, write_func):
    """Write ``to_write`` to a temporary file, then atomically publish it
    at ``filename`` via a concurrency-safe rename."""
    tmp_name = concurrency_safe_write(to_write, filename, write_func)
    concurrency_safe_rename(tmp_name, filename)
|
||||
|
||||
|
||||
@timeout(0)  # No timeout as this test can be long
@with_multiprocessing
@parametrize("backend", ["multiprocessing", "loky", "threading"])
def test_concurrency_safe_write(tmpdir, backend):
    """Concurrent writers and readers of one cache file must not corrupt it."""
    # Add one item to cache
    filename = tmpdir.join("test.pkl").strpath

    # A fairly large object so writes take long enough to overlap.
    obj = {str(i): i for i in range(int(1e5))}
    # 12 tasks: indices with i % 3 != 2 write the file, the remaining third
    # read it back and compare against the expected object.
    funcs = [
        functools.partial(concurrency_safe_write_rename, write_func=write_func)
        if i % 3 != 2
        else load_func
        for i in range(12)
    ]
    Parallel(n_jobs=2, backend=backend)(delayed(func)(obj, filename) for func in funcs)
|
||||
|
||||
|
||||
def test_warning_on_dump_failure(tmpdir):
    """Any dump failure other than a PicklingError must surface as a
    CacheWarning carrying the underlying message."""

    class ExplodingObject(object):
        def __reduce__(self):
            raise RuntimeError("some exception")

    store = FileSystemStoreBackend()
    store.location = tmpdir.join("test_warning_on_pickling_error").strpath
    store.compress = None

    with pytest.warns(CacheWarning, match="some exception"):
        store.dump_item("testpath", ExplodingObject())
|
||||
|
||||
|
||||
def test_warning_on_pickling_error(tmpdir):
    # This is separate from test_warning_on_dump_failure because in the
    # future we will turn this into an exception.
    class RaisesPicklingError(object):
        def __reduce__(self):
            raise PicklingError("not picklable")

    store = FileSystemStoreBackend()
    store.location = tmpdir.join("test_warning_on_pickling_error").strpath
    store.compress = None

    with pytest.warns(FutureWarning, match="not picklable"):
        store.dump_item("testpath", RaisesPicklingError())
|
||||
87
backend/venv/Lib/site-packages/joblib/test/test_testing.py
Normal file
87
backend/venv/Lib/site-packages/joblib/test/test_testing.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
from joblib.testing import check_subprocess_call, raises
|
||||
|
||||
|
||||
def test_check_subprocess_call():
    """Run a trivial script twice: once with no checks, once validating
    stdout against a regex."""
    code = "\n".join(
        ["result = 1 + 2 * 3", "print(result)", "my_list = [1, 2, 3]", "print(my_list)"]
    )
    cmd = [sys.executable, "-c", code]

    check_subprocess_call(cmd)

    # Now checking stdout with a regex; a regex is needed to cope with
    # platform-specific line endings.
    check_subprocess_call(cmd, stdout_regex=r"7\s{1,2}\[1, 2, 3\]")
|
||||
|
||||
|
||||
def test_check_subprocess_call_non_matching_regex():
    """A stdout_regex that matches nothing must raise a ValueError naming
    the pattern."""
    non_matching_pattern = "_no_way_this_matches_anything_"
    with raises(ValueError) as excinfo:
        check_subprocess_call(
            [sys.executable, "-c", "42"], stdout_regex=non_matching_pattern
        )
    excinfo.match("Unexpected stdout.+{}".format(non_matching_pattern))
|
||||
|
||||
|
||||
def test_check_subprocess_call_wrong_command():
    # Launching a nonexistent executable must surface as an OSError.
    with raises(OSError):
        check_subprocess_call(["_a_command_that_does_not_exist_"])
|
||||
|
||||
|
||||
def test_check_subprocess_call_non_zero_return_code():
    """A failing child must raise a ValueError whose message carries the
    return code plus the captured stdout and stderr."""
    failing_code = "\n".join(
        [
            "import sys",
            'print("writing on stdout")',
            'sys.stderr.write("writing on stderr")',
            "sys.exit(123)",
        ]
    )

    expected = re.compile(
        "Non-zero return code: 123.+"
        "Stdout:\nwriting on stdout.+"
        "Stderr:\nwriting on stderr",
        re.DOTALL,
    )

    with raises(ValueError) as excinfo:
        check_subprocess_call([sys.executable, "-c", failing_code])
    excinfo.match(expected)
|
||||
|
||||
|
||||
def test_check_subprocess_call_timeout():
    """The child prints and flushes on both streams, then sleeps far past
    the timeout; the resulting ValueError must still carry the flushed
    output."""
    sleeping_code = "\n".join(
        [
            "import time",
            "import sys",
            'print("before sleep on stdout")',
            "sys.stdout.flush()",
            'sys.stderr.write("before sleep on stderr")',
            "sys.stderr.flush()",
            # We need to sleep for at least 2 * timeout seconds in case the
            # SIGKILL is triggered.
            "time.sleep(10)",
            'print("process should have be killed before")',
            "sys.stdout.flush()",
        ]
    )

    expected = re.compile(
        "Non-zero return code:.+"
        "Stdout:\nbefore sleep on stdout\\s+"
        "Stderr:\nbefore sleep on stderr",
        re.DOTALL,
    )

    with raises(ValueError) as excinfo:
        check_subprocess_call([sys.executable, "-c", sleeping_code], timeout=1)
    excinfo.match(expected)
|
||||
45
backend/venv/Lib/site-packages/joblib/test/test_utils.py
Normal file
45
backend/venv/Lib/site-packages/joblib/test/test_utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
|
||||
from joblib._utils import eval_expr
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "expr",
    [
        "exec('import os')",
        "print(1)",
        "import os",
        "1+1; import os",
        "1^1",
        "' ' * 10**10",
        "9. ** 10000.",
    ],
)
def test_eval_expr_invalid(expr):
    # Statements, calls, xor, and expressions with huge results must all be
    # rejected with the documented ValueError.
    with pytest.raises(ValueError, match="is not a valid or supported arithmetic"):
        eval_expr(expr)
|
||||
|
||||
|
||||
def test_eval_expr_too_long():
    # Expressions beyond the allowed length are rejected before evaluation.
    overlong = "1" + "+1" * 50
    with pytest.raises(ValueError, match="is too long"):
        eval_expr(overlong)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expr", ["1e7", "10**7", "9**9**9"])
def test_eval_expr_too_large_literal(expr):
    # Numeric literals above the allowed magnitude are rejected.
    with pytest.raises(ValueError, match="Numeric literal .* is too large"):
        eval_expr(expr)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "expr, result",
    [
        ("2*6", 12),
        ("2**6", 64),
        ("1 + 2*3**(4) / (6 + -7)", -161.0),
        ("(20 // 3) % 5", 1),
    ],
)
def test_eval_expr_valid(expr, result):
    # Supported arithmetic expressions evaluate to the expected value.
    assert eval_expr(expr) == result
|
||||
9
backend/venv/Lib/site-packages/joblib/test/testutils.py
Normal file
9
backend/venv/Lib/site-packages/joblib/test/testutils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
def return_slice_of_data(arr, start_idx, end_idx):
    """Return ``arr[start_idx:end_idx]``; helper executed in worker
    processes by the memmapping tests."""
    window = slice(start_idx, end_idx)
    return arr[window]
|
||||
|
||||
|
||||
def print_filename_and_raise(arr):
    """Print the filename of the memmap backing ``arr``, then raise.

    Used by memmapping tests to learn, from inside a worker process, which
    file backs an array; the ValueError is raised deliberately so the
    caller can also check error propagation.
    """
    # Imported lazily so merely importing this module stays side-effect free.
    from joblib._memmapping_reducer import _get_backing_memmap

    print(_get_backing_memmap(arr).filename)
    raise ValueError
|
||||
Reference in New Issue
Block a user