872 lines
35 KiB
Python
872 lines
35 KiB
Python
import pytest
|
|
|
|
import numpy as np
|
|
|
|
from scipy.optimize._bracket import _ELIMITS
|
|
from scipy.optimize.elementwise import bracket_root, bracket_minimum
|
|
import scipy._lib._elementwise_iterative_method as eim
|
|
from scipy import stats
|
|
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
|
|
xp_assert_less, array_namespace)
|
|
from scipy._lib._array_api import xp_ravel
|
|
from scipy.conftest import array_api_compatible
|
|
|
|
|
|
# These tests were originally written for the private `optimize._bracket`
|
|
# interfaces, but now we want the tests to check the behavior of the public
|
|
# `optimize.elementwise` interfaces. Therefore, rather than importing
|
|
# `_bracket_root`/`_bracket_minimum` from `_bracket.py`, we import
|
|
# `bracket_root`/`bracket_minimum` from `optimize.elementwise` and wrap those
|
|
# functions to conform to the private interface. This may look a little strange,
|
|
# since it effectively just inverts the interface transformation done within the
|
|
# `bracket_root`/`bracket_minimum` functions, but it allows us to run the original,
|
|
# unmodified tests on the public interfaces, simplifying the PR that adds
|
|
# the public interfaces. We'll refactor this when we want to @parametrize the
|
|
# tests over multiple `method`s.
|
|
def _bracket_root(*args, **kwargs):
|
|
res = bracket_root(*args, **kwargs)
|
|
res.xl, res.xr = res.bracket
|
|
res.fl, res.fr = res.f_bracket
|
|
del res.bracket
|
|
del res.f_bracket
|
|
return res
|
|
|
|
|
|
def _bracket_minimum(*args, **kwargs):
|
|
res = bracket_minimum(*args, **kwargs)
|
|
res.xl, res.xm, res.xr = res.bracket
|
|
res.fl, res.fm, res.fr = res.f_bracket
|
|
del res.bracket
|
|
del res.f_bracket
|
|
return res
|
|
|
|
|
|
array_api_strict_skip_reason = 'Array API does not support fancy indexing assignment.'
|
|
jax_skip_reason = 'JAX arrays do not support item assignment.'
|
|
|
|
@pytest.mark.skip_xp_backends('array_api_strict', reason=array_api_strict_skip_reason)
|
|
@pytest.mark.skip_xp_backends('jax.numpy', reason=jax_skip_reason)
|
|
@array_api_compatible
|
|
@pytest.mark.usefixtures("skip_xp_backends")
|
|
class TestBracketRoot:
|
|
@pytest.mark.parametrize("seed", (615655101, 3141866013, 238075752))
|
|
@pytest.mark.parametrize("use_xmin", (False, True))
|
|
@pytest.mark.parametrize("other_side", (False, True))
|
|
@pytest.mark.parametrize("fix_one_side", (False, True))
|
|
def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side, xp):
|
|
# Property-based test to confirm that _bracket_root is behaving as
|
|
# expected. The basic case is when root < a < b.
|
|
# The number of times bracket expands (per side) can be found by
|
|
# setting the expression for the left endpoint of the bracket to the
|
|
# root of f (x=0), solving for i, and rounding up. The corresponding
|
|
# lower and upper ends of the bracket are found by plugging this back
|
|
# into the expression for the ends of the bracket.
|
|
# `other_side=True` is the case that a < b < root
|
|
# Special cases like a < root < b are tested separately
|
|
rng = np.random.default_rng(seed)
|
|
xl0, d, factor = xp.asarray(rng.random(size=3) * [1e5, 10, 5])
|
|
factor = 1 + factor # factor must be greater than 1
|
|
xr0 = xl0 + d # xr0 must be greater than a in basic case
|
|
|
|
def f(x):
|
|
f.count += 1
|
|
return x # root is 0
|
|
|
|
if use_xmin:
|
|
xmin = xp.asarray(-rng.random())
|
|
n = xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor))
|
|
l, u = xmin + (xl0 - xmin)*factor**-n, xmin + (xl0 - xmin)*factor**-(n - 1)
|
|
kwargs = dict(xl0=xl0, xr0=xr0, factor=factor, xmin=xmin)
|
|
else:
|
|
n = xp.ceil(xp.log(xr0/d) / xp.log(factor))
|
|
l, u = xr0 - d*factor**n, xr0 - d*factor**(n-1)
|
|
kwargs = dict(xl0=xl0, xr0=xr0, factor=factor)
|
|
|
|
if other_side:
|
|
kwargs['xl0'], kwargs['xr0'] = -kwargs['xr0'], -kwargs['xl0']
|
|
l, u = -u, -l
|
|
if 'xmin' in kwargs:
|
|
kwargs['xmax'] = -kwargs.pop('xmin')
|
|
|
|
if fix_one_side:
|
|
if other_side:
|
|
kwargs['xmin'] = -xr0
|
|
else:
|
|
kwargs['xmax'] = xr0
|
|
|
|
f.count = 0
|
|
res = _bracket_root(f, **kwargs)
|
|
|
|
# Compare reported number of function evaluations `nfev` against
|
|
# reported `nit`, actual function call count `f.count`, and theoretical
|
|
# number of expansions `n`.
|
|
# When both sides are free, these get multiplied by 2 because function
|
|
# is evaluated on the left and the right each iteration.
|
|
# When one side is fixed, however, we add one: on the right side, the
|
|
# function gets evaluated once at b.
|
|
# Add 1 to `n` and `res.nit` because function evaluations occur at
|
|
# iterations *0*, 1, ..., `n`. Subtract 1 from `f.count` because
|
|
# function is called separately for left and right in iteration 0.
|
|
if not fix_one_side:
|
|
assert res.nfev == 2*(res.nit+1) == 2*(f.count-1) == 2*(n + 1)
|
|
else:
|
|
assert res.nfev == (res.nit+1)+1 == (f.count-1)+1 == (n+1)+1
|
|
|
|
# Compare reported bracket to theoretical bracket and reported function
|
|
# values to function evaluated at bracket.
|
|
bracket = xp.asarray([res.xl, res.xr])
|
|
xp_assert_close(bracket, xp.asarray([l, u]))
|
|
f_bracket = xp.asarray([res.fl, res.fr])
|
|
xp_assert_close(f_bracket, f(bracket))
|
|
|
|
# Check that bracket is valid and that status and success are correct
|
|
assert res.xr > res.xl
|
|
signs = xp.sign(f_bracket)
|
|
assert signs[0] == -signs[1]
|
|
assert res.status == 0
|
|
assert res.success
|
|
|
|
def f(self, q, p):
|
|
return stats._stats_py._SimpleNormal().cdf(q) - p
|
|
|
|
@pytest.mark.parametrize('p', [0.6, np.linspace(0.05, 0.95, 10)])
|
|
@pytest.mark.parametrize('xmin', [-5, None])
|
|
@pytest.mark.parametrize('xmax', [5, None])
|
|
@pytest.mark.parametrize('factor', [1.2, 2])
|
|
def test_basic(self, p, xmin, xmax, factor, xp):
|
|
# Test basic functionality to bracket root (distribution PPF)
|
|
res = _bracket_root(self.f, xp.asarray(-0.01), 0.01, xmin=xmin, xmax=xmax,
|
|
factor=factor, args=(xp.asarray(p),))
|
|
xp_assert_equal(-xp.sign(res.fl), xp.sign(res.fr))
|
|
|
|
@pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
|
|
def test_vectorization(self, shape, xp):
|
|
# Test for correct functionality, output shapes, and dtypes for various
|
|
# input shapes.
|
|
p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else np.float64(0.6)
|
|
args = (p,)
|
|
maxiter = 10
|
|
|
|
@np.vectorize
|
|
def bracket_root_single(xl0, xr0, xmin, xmax, factor, p):
|
|
return _bracket_root(self.f, xl0, xr0, xmin=xmin, xmax=xmax,
|
|
factor=factor, args=(p,),
|
|
maxiter=maxiter)
|
|
|
|
def f(*args, **kwargs):
|
|
f.f_evals += 1
|
|
return self.f(*args, **kwargs)
|
|
f.f_evals = 0
|
|
|
|
rng = np.random.default_rng(2348234)
|
|
xl0 = -rng.random(size=shape)
|
|
xr0 = rng.random(size=shape)
|
|
xmin, xmax = 1e3*xl0, 1e3*xr0
|
|
if shape: # make some elements un
|
|
i = rng.random(size=shape) > 0.5
|
|
xmin[i], xmax[i] = -np.inf, np.inf
|
|
factor = rng.random(size=shape) + 1.5
|
|
refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel()
|
|
xl0, xr0, xmin, xmax, factor = (xp.asarray(xl0), xp.asarray(xr0),
|
|
xp.asarray(xmin), xp.asarray(xmax),
|
|
xp.asarray(factor))
|
|
args = tuple(map(xp.asarray, args))
|
|
res = _bracket_root(f, xl0, xr0, xmin=xmin, xmax=xmax, factor=factor,
|
|
args=args, maxiter=maxiter)
|
|
|
|
attrs = ['xl', 'xr', 'fl', 'fr', 'success', 'nfev', 'nit']
|
|
for attr in attrs:
|
|
ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
|
|
res_attr = getattr(res, attr)
|
|
xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
|
|
xp_assert_equal(res_attr.shape, shape)
|
|
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
assert res.success.dtype == xp_test.bool
|
|
if shape:
|
|
assert xp.all(res.success[1:-1])
|
|
assert res.status.dtype == xp.int32
|
|
assert res.nfev.dtype == xp.int32
|
|
assert res.nit.dtype == xp.int32
|
|
assert xp.max(res.nit) == f.f_evals - 2
|
|
xp_assert_less(res.xl, res.xr)
|
|
xp_assert_close(res.fl, xp.asarray(self.f(res.xl, *args)))
|
|
xp_assert_close(res.fr, xp.asarray(self.f(res.xr, *args)))
|
|
|
|
def test_flags(self, xp):
|
|
# Test cases that should produce different status flags; show that all
|
|
# can be produced simultaneously.
|
|
def f(xs, js):
|
|
funcs = [lambda x: x - 1.5,
|
|
lambda x: x - 1000,
|
|
lambda x: x - 1000,
|
|
lambda x: x * xp.nan,
|
|
lambda x: x]
|
|
|
|
return [funcs[int(j)](x) for x, j in zip(xs, js)]
|
|
|
|
args = (xp.arange(5, dtype=xp.int64),)
|
|
res = _bracket_root(f,
|
|
xl0=xp.asarray([-1., -1., -1., -1., 4.]),
|
|
xr0=xp.asarray([1, 1, 1, 1, -4]),
|
|
xmin=xp.asarray([-xp.inf, -1, -xp.inf, -xp.inf, 6]),
|
|
xmax=xp.asarray([xp.inf, 1, xp.inf, xp.inf, 2]),
|
|
args=args, maxiter=3)
|
|
|
|
ref_flags = xp.asarray([eim._ECONVERGED,
|
|
_ELIMITS,
|
|
eim._ECONVERR,
|
|
eim._EVALUEERR,
|
|
eim._EINPUTERR],
|
|
dtype=xp.int32)
|
|
|
|
xp_assert_equal(res.status, ref_flags)
|
|
|
|
@pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
|
|
@pytest.mark.parametrize('xmin', [-5, None])
|
|
@pytest.mark.parametrize('xmax', [5, None])
|
|
@pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
|
|
def test_dtype(self, root, xmin, xmax, dtype, xp):
|
|
# Test that dtypes are preserved
|
|
dtype = getattr(xp, dtype)
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
|
|
xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
|
|
xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
|
|
root = xp.asarray(root, dtype=dtype)
|
|
def f(x, root):
|
|
return xp_test.astype((x - root) ** 3, dtype)
|
|
|
|
bracket = xp.asarray([-0.01, 0.01], dtype=dtype)
|
|
res = _bracket_root(f, *bracket, xmin=xmin, xmax=xmax, args=(root,))
|
|
assert xp.all(res.success)
|
|
assert res.xl.dtype == res.xr.dtype == dtype
|
|
assert res.fl.dtype == res.fr.dtype == dtype
|
|
|
|
def test_input_validation(self, xp):
|
|
# Test input validation for appropriate error messages
|
|
|
|
message = '`func` must be callable.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(None, -4, 4)
|
|
|
|
message = '...must be numeric and real.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4+1j, 4)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 'hello')
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, xmin=np)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, xmax=object())
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, factor=sum)
|
|
|
|
message = "All elements of `factor` must be greater than 1."
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, factor=0.5)
|
|
|
|
message = "broadcast"
|
|
# raised by `xp.broadcast, but the traceback is readable IMO
|
|
with pytest.raises(Exception, match=message):
|
|
_bracket_root(lambda x: x, xp.asarray([-2, -3]), xp.asarray([3, 4, 5]))
|
|
# Consider making this give a more readable error message
|
|
# with pytest.raises(ValueError, match=message):
|
|
# _bracket_root(lambda x: [x[0], x[1], x[1]], [-3, -3], [5, 5])
|
|
|
|
message = '`maxiter` must be a non-negative integer.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, maxiter=1.5)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, maxiter=-1)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_root(lambda x: x, -4, 4, maxiter="shrubbery")
|
|
|
|
def test_special_cases(self, xp):
|
|
# Test edge cases and other special cases
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
|
|
# Test that integers are not passed to `f`
|
|
# (otherwise this would overflow)
|
|
def f(x):
|
|
assert xp_test.isdtype(x.dtype, "real floating")
|
|
return x ** 99 - 1
|
|
|
|
res = _bracket_root(f, xp.asarray(-7.), xp.asarray(5.))
|
|
assert res.success
|
|
|
|
# Test maxiter = 0. Should do nothing to bracket.
|
|
def f(x):
|
|
return x - 10
|
|
|
|
bracket = (xp.asarray(-3.), xp.asarray(5.))
|
|
res = _bracket_root(f, *bracket, maxiter=0)
|
|
assert res.xl, res.xr == bracket
|
|
assert res.nit == 0
|
|
assert res.nfev == 2
|
|
assert res.status == -2
|
|
|
|
# Test scalar `args` (not in tuple)
|
|
def f(x, c):
|
|
return c*x - 1
|
|
|
|
res = _bracket_root(f, xp.asarray(-1.), xp.asarray(1.),
|
|
args=xp.asarray(3.))
|
|
assert res.success
|
|
xp_assert_close(res.fl, f(res.xl, 3))
|
|
|
|
# Test other edge cases
|
|
|
|
def f(x):
|
|
f.count += 1
|
|
return x
|
|
|
|
# 1. root lies within guess of bracket
|
|
f.count = 0
|
|
_bracket_root(f, xp.asarray(-10), xp.asarray(20))
|
|
assert f.count == 2
|
|
|
|
# 2. bracket endpoint hits root exactly
|
|
f.count = 0
|
|
res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
|
|
factor=2)
|
|
|
|
assert res.nfev == 4
|
|
xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)
|
|
xp_assert_close(res.xr, xp.asarray(5.), atol=1e-15)
|
|
|
|
# 3. bracket limit hits root exactly
|
|
with np.errstate(over='ignore'):
|
|
res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
|
|
xmin=0)
|
|
xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)
|
|
|
|
with np.errstate(over='ignore'):
|
|
res = _bracket_root(f, xp.asarray(-10.), xp.asarray(-5.),
|
|
xmax=0)
|
|
xp_assert_close(res.xr, xp.asarray(0.), atol=1e-15)
|
|
|
|
# 4. bracket not within min, max
|
|
with np.errstate(over='ignore'):
|
|
res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
|
|
xmin=1)
|
|
assert not res.success
|
|
|
|
|
|
@pytest.mark.skip_xp_backends('array_api_strict', reason=array_api_strict_skip_reason)
|
|
@pytest.mark.skip_xp_backends('jax.numpy', reason=jax_skip_reason)
|
|
@array_api_compatible
|
|
@pytest.mark.usefixtures("skip_xp_backends")
|
|
class TestBracketMinimum:
|
|
def init_f(self):
|
|
def f(x, a, b):
|
|
f.count += 1
|
|
return (x - a)**2 + b
|
|
f.count = 0
|
|
return f
|
|
|
|
def assert_valid_bracket(self, result, xp):
|
|
assert xp.all(
|
|
(result.xl < result.xm) & (result.xm < result.xr)
|
|
)
|
|
assert xp.all(
|
|
(result.fl >= result.fm) & (result.fr > result.fm)
|
|
| (result.fl > result.fm) & (result.fr > result.fm)
|
|
)
|
|
|
|
def get_kwargs(
|
|
self, *, xl0=None, xr0=None, factor=None, xmin=None, xmax=None, args=None
|
|
):
|
|
names = ("xl0", "xr0", "xmin", "xmax", "factor", "args")
|
|
return {
|
|
name: val for name, val in zip(names, (xl0, xr0, xmin, xmax, factor, args))
|
|
if val is not None
|
|
}
|
|
|
|
@pytest.mark.parametrize(
|
|
"seed",
|
|
(
|
|
307448016549685229886351382450158984917,
|
|
11650702770735516532954347931959000479,
|
|
113767103358505514764278732330028568336,
|
|
)
|
|
)
|
|
@pytest.mark.parametrize("use_xmin", (False, True))
|
|
@pytest.mark.parametrize("other_side", (False, True))
|
|
def test_nfev_expected(self, seed, use_xmin, other_side, xp):
|
|
rng = np.random.default_rng(seed)
|
|
args = (xp.asarray(0.), xp.asarray(0.)) # f(x) = x^2 with minimum at 0
|
|
# xl0, xm0, xr0 are chosen such that the initial bracket is to
|
|
# the right of the minimum, and the bracket will expand
|
|
# downhill towards zero.
|
|
xl0, d1, d2, factor = xp.asarray(rng.random(size=4) * [1e5, 10, 10, 5])
|
|
xm0 = xl0 + d1
|
|
xr0 = xm0 + d2
|
|
# Factor should be greater than one.
|
|
factor += 1
|
|
|
|
if use_xmin:
|
|
xmin = xp.asarray(-rng.random() * 5, dtype=xp.float64)
|
|
n = int(xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor)))
|
|
lower = xmin + (xl0 - xmin)*factor**-n
|
|
middle = xmin + (xl0 - xmin)*factor**-(n-1)
|
|
upper = xmin + (xl0 - xmin)*factor**-(n-2) if n > 1 else xm0
|
|
# It may be the case the lower is below the minimum, but we still
|
|
# don't have a valid bracket.
|
|
if middle**2 > lower**2:
|
|
n += 1
|
|
lower, middle, upper = (
|
|
xmin + (xl0 - xmin)*factor**-n, lower, middle
|
|
)
|
|
else:
|
|
xmin = None
|
|
n = int(xp.ceil(xp.log(xl0 / d1) / xp.log(factor)))
|
|
lower = xl0 - d1*factor**n
|
|
middle = xl0 - d1*factor**(n-1) if n > 1 else xl0
|
|
upper = xl0 - d1*factor**(n-2) if n > 1 else xm0
|
|
# It may be the case the lower is below the minimum, but we still
|
|
# don't have a valid bracket.
|
|
if middle**2 > lower**2:
|
|
n += 1
|
|
lower, middle, upper = (
|
|
xl0 - d1*factor**n, lower, middle
|
|
)
|
|
f = self.init_f()
|
|
|
|
xmax = None
|
|
if other_side:
|
|
xl0, xm0, xr0 = -xr0, -xm0, -xl0
|
|
xmin, xmax = None, -xmin if xmin is not None else None
|
|
lower, middle, upper = -upper, -middle, -lower
|
|
|
|
kwargs = self.get_kwargs(
|
|
xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, factor=factor, args=args
|
|
)
|
|
result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
|
|
|
|
# Check that `nfev` and `nit` have the correct relationship
|
|
assert result.nfev == result.nit + 3
|
|
# Check that `nfev` reports the correct number of function evaluations.
|
|
assert result.nfev == f.count
|
|
# Check that the number of iterations matches the theoretical value.
|
|
assert result.nit == n
|
|
|
|
# Compare reported bracket to theoretical bracket and reported function
|
|
# values to function evaluated at bracket.
|
|
xp_assert_close(result.xl, lower)
|
|
xp_assert_close(result.xm, middle)
|
|
xp_assert_close(result.xr, upper)
|
|
xp_assert_close(result.fl, f(lower, *args))
|
|
xp_assert_close(result.fm, f(middle, *args))
|
|
xp_assert_close(result.fr, f(upper, *args))
|
|
|
|
self.assert_valid_bracket(result, xp)
|
|
assert result.status == 0
|
|
assert result.success
|
|
|
|
def test_flags(self, xp):
|
|
# Test cases that should produce different status flags; show that all
|
|
# can be produced simultaneously
|
|
def f(xs, js):
|
|
funcs = [lambda x: (x - 1.5)**2,
|
|
lambda x: x,
|
|
lambda x: x,
|
|
lambda x: xp.nan,
|
|
lambda x: x**2]
|
|
|
|
return [funcs[j](x) for x, j in zip(xs, js)]
|
|
|
|
args = (xp.arange(5, dtype=xp.int64),)
|
|
xl0 = xp.asarray([-1.0, -1.0, -1.0, -1.0, 6.0])
|
|
xm0 = xp.asarray([0.0, 0.0, 0.0, 0.0, 4.0])
|
|
xr0 = xp.asarray([1.0, 1.0, 1.0, 1.0, 2.0])
|
|
xmin = xp.asarray([-xp.inf, -1.0, -xp.inf, -xp.inf, 8.0])
|
|
|
|
result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin,
|
|
args=args, maxiter=3)
|
|
|
|
reference_flags = xp.asarray([eim._ECONVERGED, _ELIMITS,
|
|
eim._ECONVERR, eim._EVALUEERR,
|
|
eim._EINPUTERR], dtype=xp.int32)
|
|
xp_assert_equal(result.status, reference_flags)
|
|
|
|
@pytest.mark.parametrize("minimum", (0.622, [0.622, 0.623]))
|
|
@pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
|
|
@pytest.mark.parametrize("xmin", [-5, None])
|
|
@pytest.mark.parametrize("xmax", [5, None])
|
|
def test_dtypes(self, minimum, xmin, xmax, dtype, xp):
|
|
dtype = getattr(xp, dtype)
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
|
|
xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
|
|
minimum = xp.asarray(minimum, dtype=dtype)
|
|
|
|
def f(x, minimum):
|
|
return xp_test.astype((x - minimum)**2, dtype)
|
|
|
|
xl0, xm0, xr0 = [-0.01, 0.0, 0.01]
|
|
result = _bracket_minimum(
|
|
f, xp.asarray(xm0, dtype=dtype), xl0=xp.asarray(xl0, dtype=dtype),
|
|
xr0=xp.asarray(xr0, dtype=dtype), xmin=xmin, xmax=xmax, args=(minimum, )
|
|
)
|
|
assert xp.all(result.success)
|
|
assert result.xl.dtype == result.xm.dtype == result.xr.dtype == dtype
|
|
assert result.fl.dtype == result.fm.dtype == result.fr.dtype == dtype
|
|
|
|
@pytest.mark.skip_xp_backends(np_only=True, reason="str/object arrays")
|
|
def test_input_validation(self, xp):
|
|
# Test input validation for appropriate error messages
|
|
|
|
message = '`func` must be callable.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(None, -4, xl0=4)
|
|
|
|
message = '...must be numeric and real.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(4+1j))
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xl0='hello')
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4),
|
|
xr0='farcical aquatic ceremony')
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xmin=np)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xmax=object())
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), factor=sum)
|
|
|
|
message = "All elements of `factor` must be greater than 1."
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x, xp.asarray(-4), factor=0.5)
|
|
|
|
message = "shape mismatch: objects cannot be broadcast"
|
|
# raised by `xp.broadcast, but the traceback is readable IMO
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray([-2, -3]), xl0=[-3, -4, -5])
|
|
|
|
message = '`maxiter` must be a non-negative integer.'
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=1.5)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=-1)
|
|
with pytest.raises(ValueError, match=message):
|
|
_bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter="ekki")
|
|
|
|
@pytest.mark.parametrize("xl0", [0.0, None])
|
|
@pytest.mark.parametrize("xm0", (0.05, 0.1, 0.15))
|
|
@pytest.mark.parametrize("xr0", (0.2, 0.4, 0.6, None))
|
|
# Minimum is ``a`` for each tuple ``(a, b)`` below. Tests cases where minimum
|
|
# is within, or at varying distances to the left or right of the initial
|
|
# bracket.
|
|
@pytest.mark.parametrize(
|
|
"args",
|
|
(
|
|
(1.2, 0), (-0.5, 0), (0.1, 0), (0.2, 0), (3.6, 0), (21.4, 0),
|
|
(121.6, 0), (5764.1, 0), (-6.4, 0), (-12.9, 0), (-146.2, 0)
|
|
)
|
|
)
|
|
def test_scalar_no_limits(self, xl0, xm0, xr0, args, xp):
|
|
f = self.init_f()
|
|
kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, args=tuple(map(xp.asarray, args)))
|
|
result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
|
|
self.assert_valid_bracket(result, xp)
|
|
assert result.status == 0
|
|
assert result.success
|
|
assert result.nfev == f.count
|
|
|
|
@pytest.mark.parametrize(
|
|
# xmin is set at 0.0 in all cases.
|
|
"xl0,xm0,xr0,xmin",
|
|
(
|
|
# Initial bracket at varying distances from the xmin.
|
|
(0.5, 0.75, 1.0, 0.0),
|
|
(1.0, 2.5, 4.0, 0.0),
|
|
(2.0, 4.0, 6.0, 0.0),
|
|
(12.0, 16.0, 20.0, 0.0),
|
|
# Test default initial left endpoint selection. It should not
|
|
# be below xmin.
|
|
(None, 0.75, 1.0, 0.0),
|
|
(None, 2.5, 4.0, 0.0),
|
|
(None, 4.0, 6.0, 0.0),
|
|
(None, 16.0, 20.0, 0.0),
|
|
)
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"args", (
|
|
(0.0, 0.0), # Minimum is directly at xmin.
|
|
(1e-300, 0.0), # Minimum is extremely close to xmin.
|
|
(1e-20, 0.0), # Minimum is very close to xmin.
|
|
# Minimum at varying distances from xmin.
|
|
(0.1, 0.0),
|
|
(0.2, 0.0),
|
|
(0.4, 0.0)
|
|
)
|
|
)
|
|
def test_scalar_with_limit_left(self, xl0, xm0, xr0, xmin, args, xp):
|
|
f = self.init_f()
|
|
kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmin=xmin,
|
|
args=tuple(map(xp.asarray, args)))
|
|
result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
|
|
self.assert_valid_bracket(result, xp)
|
|
assert result.status == 0
|
|
assert result.success
|
|
assert result.nfev == f.count
|
|
|
|
@pytest.mark.parametrize(
|
|
#xmax is set to 1.0 in all cases.
|
|
"xl0,xm0,xr0,xmax",
|
|
(
|
|
# Bracket at varying distances from xmax.
|
|
(0.2, 0.3, 0.4, 1.0),
|
|
(0.05, 0.075, 0.1, 1.0),
|
|
(-0.2, -0.1, 0.0, 1.0),
|
|
(-21.2, -17.7, -14.2, 1.0),
|
|
# Test default right endpoint selection. It should not exceed xmax.
|
|
(0.2, 0.3, None, 1.0),
|
|
(0.05, 0.075, None, 1.0),
|
|
(-0.2, -0.1, None, 1.0),
|
|
(-21.2, -17.7, None, 1.0),
|
|
)
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"args", (
|
|
(0.9999999999999999, 0.0), # Minimum very close to xmax.
|
|
# Minimum at varying distances from xmax.
|
|
(0.9, 0.0),
|
|
(0.7, 0.0),
|
|
(0.5, 0.0)
|
|
)
|
|
)
|
|
def test_scalar_with_limit_right(self, xl0, xm0, xr0, xmax, args, xp):
|
|
f = self.init_f()
|
|
args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
|
|
kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmax=xmax, args=args)
|
|
result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
|
|
self.assert_valid_bracket(result, xp)
|
|
assert result.status == 0
|
|
assert result.success
|
|
assert result.nfev == f.count
|
|
|
|
@pytest.mark.parametrize(
|
|
"xl0,xm0,xr0,xmin,xmax,args",
|
|
(
|
|
( # Case 1:
|
|
# Initial bracket.
|
|
0.2,
|
|
0.3,
|
|
0.4,
|
|
# Function slopes down to the right from the bracket to a minimum
|
|
# at 1.0. xmax is also at 1.0
|
|
None,
|
|
1.0,
|
|
(1.0, 0.0)
|
|
),
|
|
( # Case 2:
|
|
# Initial bracket.
|
|
1.4,
|
|
1.95,
|
|
2.5,
|
|
# Function slopes down to the left from the bracket to a minimum at
|
|
# 0.3 with xmin set to 0.3.
|
|
0.3,
|
|
None,
|
|
(0.3, 0.0)
|
|
),
|
|
(
|
|
# Case 3:
|
|
# Initial bracket.
|
|
2.6,
|
|
3.25,
|
|
3.9,
|
|
# Function slopes down and to the right to a minimum at 99.4 with xmax
|
|
# at 99.4. Tests case where minimum is at xmax relatively further from
|
|
# the bracket.
|
|
None,
|
|
99.4,
|
|
(99.4, 0)
|
|
),
|
|
(
|
|
# Case 4:
|
|
# Initial bracket.
|
|
4,
|
|
4.5,
|
|
5,
|
|
# Function slopes down and to the left away from the bracket with a
|
|
# minimum at -26.3 with xmin set to -26.3. Tests case where minimum is
|
|
# at xmin relatively far from the bracket.
|
|
-26.3,
|
|
None,
|
|
(-26.3, 0)
|
|
),
|
|
(
|
|
# Case 5:
|
|
# Similar to Case 1 above, but tests default values of xl0 and xr0.
|
|
None,
|
|
0.3,
|
|
None,
|
|
None,
|
|
1.0,
|
|
(1.0, 0.0)
|
|
),
|
|
( # Case 6:
|
|
# Similar to Case 2 above, but tests default values of xl0 and xr0.
|
|
None,
|
|
1.95,
|
|
None,
|
|
0.3,
|
|
None,
|
|
(0.3, 0.0)
|
|
),
|
|
(
|
|
# Case 7:
|
|
# Similar to Case 3 above, but tests default values of xl0 and xr0.
|
|
None,
|
|
3.25,
|
|
None,
|
|
None,
|
|
99.4,
|
|
(99.4, 0)
|
|
),
|
|
(
|
|
# Case 8:
|
|
# Similar to Case 4 above, but tests default values of xl0 and xr0.
|
|
None,
|
|
4.5,
|
|
None,
|
|
-26.3,
|
|
None,
|
|
(-26.3, 0)
|
|
),
|
|
)
|
|
)
|
|
def test_minimum_at_boundary_point(self, xl0, xm0, xr0, xmin, xmax, args, xp):
|
|
f = self.init_f()
|
|
kwargs = self.get_kwargs(xr0=xr0, xmin=xmin, xmax=xmax,
|
|
args=tuple(map(xp.asarray, args)))
|
|
result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
|
|
assert result.status == -1
|
|
assert args[0] in (result.xl, result.xr)
|
|
assert result.nfev == f.count
|
|
|
|
@pytest.mark.parametrize('shape', [tuple(), (12, ), (3, 4), (3, 2, 2)])
|
|
def test_vectorization(self, shape, xp):
|
|
# Test for correct functionality, output shapes, and dtypes for
|
|
# various input shapes.
|
|
a = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
|
|
args = (a, 0.)
|
|
maxiter = 10
|
|
|
|
@np.vectorize
|
|
def bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a):
|
|
return _bracket_minimum(self.init_f(), xm0, xl0=xl0, xr0=xr0, xmin=xmin,
|
|
xmax=xmax, factor=factor, maxiter=maxiter,
|
|
args=(a, 0.0))
|
|
|
|
f = self.init_f()
|
|
|
|
rng = np.random.default_rng(2348234)
|
|
xl0 = -rng.random(size=shape)
|
|
xr0 = rng.random(size=shape)
|
|
xm0 = xl0 + rng.random(size=shape) * (xr0 - xl0)
|
|
xmin, xmax = 1e3*xl0, 1e3*xr0
|
|
if shape: # make some elements un
|
|
i = rng.random(size=shape) > 0.5
|
|
xmin[i], xmax[i] = -np.inf, np.inf
|
|
factor = rng.random(size=shape) + 1.5
|
|
refs = bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a).ravel()
|
|
args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
|
|
res = _bracket_minimum(f, xp.asarray(xm0), xl0=xl0, xr0=xr0, xmin=xmin,
|
|
xmax=xmax, factor=factor, args=args, maxiter=maxiter)
|
|
|
|
attrs = ['xl', 'xm', 'xr', 'fl', 'fm', 'fr', 'success', 'nfev', 'nit']
|
|
for attr in attrs:
|
|
ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
|
|
res_attr = getattr(res, attr)
|
|
xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
|
|
xp_assert_equal(res_attr.shape, shape)
|
|
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
assert res.success.dtype == xp_test.bool
|
|
if shape:
|
|
assert xp.all(res.success[1:-1])
|
|
assert res.status.dtype == xp.int32
|
|
assert res.nfev.dtype == xp.int32
|
|
assert res.nit.dtype == xp.int32
|
|
assert xp.max(res.nit) == f.count - 3
|
|
self.assert_valid_bracket(res, xp)
|
|
xp_assert_close(res.fl, f(res.xl, *args))
|
|
xp_assert_close(res.fm, f(res.xm, *args))
|
|
xp_assert_close(res.fr, f(res.xr, *args))
|
|
|
|
def test_special_cases(self, xp):
|
|
# Test edge cases and other special cases.
|
|
xp_test = array_namespace(xp.asarray(1.))
|
|
|
|
# Test that integers are not passed to `f`
|
|
# (otherwise this would overflow)
|
|
def f(x):
|
|
assert xp_test.isdtype(x.dtype, "numeric")
|
|
return x ** 98 - 1
|
|
|
|
result = _bracket_minimum(f, xp.asarray(-7., dtype=xp.float64), xr0=5)
|
|
assert result.success
|
|
|
|
# Test maxiter = 0. Should do nothing to bracket.
|
|
def f(x):
|
|
return x**2 - 10
|
|
|
|
xl0, xm0, xr0 = xp.asarray(-3.), xp.asarray(-1.), xp.asarray(2.)
|
|
result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, maxiter=0)
|
|
xp_assert_equal(result.xl, xl0)
|
|
xp_assert_equal(result.xm, xm0)
|
|
xp_assert_equal(result.xr, xr0)
|
|
|
|
# Test scalar `args` (not in tuple)
|
|
def f(x, c):
|
|
return c*x**2 - 1
|
|
|
|
result = _bracket_minimum(f, xp.asarray(-1.), args=xp.asarray(3.))
|
|
assert result.success
|
|
xp_assert_close(result.fl, f(result.xl, 3))
|
|
|
|
# Initial bracket is valid.
|
|
f = self.init_f()
|
|
xl0, xm0, xr0 = xp.asarray(-1.0), xp.asarray(-0.2), xp.asarray(1.0)
|
|
args = (xp.asarray(0.), xp.asarray(0.))
|
|
result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, args=args)
|
|
assert f.count == 3
|
|
|
|
xp_assert_equal(result.xl, xl0)
|
|
xp_assert_equal(result.xm , xm0)
|
|
xp_assert_equal(result.xr, xr0)
|
|
xp_assert_equal(result.fl, f(xl0, *args))
|
|
xp_assert_equal(result.fm, f(xm0, *args))
|
|
xp_assert_equal(result.fr, f(xr0, *args))
|
|
|
|
def test_gh_20562_left(self, xp):
|
|
# Regression test for https://github.com/scipy/scipy/issues/20562
|
|
# minimum of f in [xmin, xmax] is at xmin.
|
|
xmin, xmax = xp.asarray(0.21933608), xp.asarray(1.39713606)
|
|
|
|
def f(x):
|
|
log_a, log_b = xp.log(xmin), xp.log(xmax)
|
|
return -((log_b - log_a)*x)**-1
|
|
|
|
result = _bracket_minimum(f, xp.asarray(0.5535723499480897), xmin=xmin,
|
|
xmax=xmax)
|
|
assert xmin == result.xl
|
|
|
|
def test_gh_20562_right(self, xp):
|
|
# Regression test for https://github.com/scipy/scipy/issues/20562
|
|
# minimum of f in [xmin, xmax] is at xmax.
|
|
xmin, xmax = xp.asarray(-1.39713606), xp.asarray(-0.21933608)
|
|
|
|
def f(x):
|
|
log_a, log_b = xp.log(-xmax), xp.log(-xmin)
|
|
return ((log_b - log_a)*x)**-1
|
|
|
|
result = _bracket_minimum(f, xp.asarray(-0.5535723499480897),
|
|
xmin=xmin, xmax=xmax)
|
|
assert xmax == result.xr
|