Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
* Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716)
* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733)
* Improved documentation of `file` argument in `dpnp.fromfile` [#2745](https://github.com/IntelPython/dpnp/pull/2745)
* Aligned `strides` property of `dpnp.ndarray` with NumPy and CuPy implementations [#2747](https://github.com/IntelPython/dpnp/pull/2747)

### Deprecated

Expand Down
87 changes: 66 additions & 21 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def __init__(
else:
buffer = usm_type

if strides is not None:
# dpctl expects strides as elements displacement in memory,
# while dpnp (and numpy as well) relies on bytes displacement
if dtype is None:
dtype = dpnp.default_float_type(
device=device, sycl_queue=sycl_queue
)
it_sz = dpnp.dtype(dtype).itemsize
strides = tuple(el // it_sz for el in strides)

sycl_queue_normalized = dpnp.get_normalized_queue_device(
device=device, sycl_queue=sycl_queue
)
Expand Down Expand Up @@ -1855,16 +1865,53 @@ def std(
@property
def strides(self):
"""
Return memory displacement in array elements, upon unit
change of respective index.
Tuple of bytes to step in each dimension when traversing an array.

For example, for strides ``(s1, s2, s3)`` and multi-index
``(i1, i2, i3)`` position of the respective element relative
to zero multi-index element is ``s1*s1 + s2*i2 + s3*i3``.
The byte offset of element ``(i[0], i[1], ..., i[n])`` in an array `a`
is::

"""
offset = sum(dpnp.array(i) * a.strides)

return self._array_obj.strides
For full documentation refer to :obj:`numpy.ndarray.strides`.

See Also
--------
:obj:`dpnp.lib.stride_tricks.as_strided` : Return a view into the array
with given shape and strides.

Examples
--------
>>> import dpnp as np
>>> y = np.reshape(np.arange(2 * 3 * 4, dtype=np.int32), (2, 3, 4))
>>> y
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=np.int32)
>>> y.strides
(48, 16, 4)
>>> y[1, 1, 1]
array(17, dtype=int32)
>>> offset = sum(i * s for i, s in zip((1, 1, 1), y.strides))
>>> offset // y.itemsize
17

>>> x = np.reshape(np.arange(5*6*7*8, dtype=np.int32), (5, 6, 7, 8))
>>> x = x.transpose(2, 3, 1, 0)
>>> x.strides
(32, 4, 224, 1344)
>>> offset = sum(i * s for i, s in zip((3, 5, 2, 2), x.strides))
>>> x[3, 5, 2, 2]
array(813, dtype=int32)
>>> offset // x.itemsize
813

"""

it_sz = self.itemsize
return tuple(el * it_sz for el in self._array_obj.strides)

def sum(
self,
Expand Down Expand Up @@ -2335,23 +2382,20 @@ def view(self, /, dtype=None, *, type=None):

# resize on last axis only
axis = ndim - 1
if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1:
if (
old_sh[axis] != 1
and self.size != 0
and old_strides[axis] != old_itemsz
):
raise ValueError(
"To change to a dtype of a different size, "
"the last axis must be contiguous"
)

# normalize strides whenever itemsize changes
if old_itemsz > new_itemsz:
new_strides = list(
el * (old_itemsz // new_itemsz) for el in old_strides
)
else:
new_strides = list(
el // (new_itemsz // old_itemsz) for el in old_strides
)
new_strides[axis] = 1
new_strides = tuple(new_strides)
new_strides = tuple(
old_strides[i] if i != axis else new_itemsz for i in range(ndim)
)

new_dim = old_sh[axis] * old_itemsz
if new_dim % new_itemsz != 0:
Expand All @@ -2361,9 +2405,10 @@ def view(self, /, dtype=None, *, type=None):
)

# normalize shape whenever itemsize changes
new_sh = list(old_sh)
new_sh[axis] = new_dim // new_itemsz
new_sh = tuple(new_sh)
new_sh = tuple(
old_sh[i] if i != axis else new_dim // new_itemsz
for i in range(ndim)
)

return dpnp_array(
new_sh,
Expand Down
6 changes: 3 additions & 3 deletions dpnp/dpnp_iface_arraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_empty_array(
elif a.flags.c_contiguous:
order = "C"
else:
strides = _get_strides_for_order_k(a, _shape)
strides = _get_strides_for_order_k(a, _dtype, shape=_shape)
order = "C"
elif order not in "cfCF":
raise ValueError(
Expand All @@ -122,15 +122,15 @@ def _get_empty_array(
)


def _get_strides_for_order_k(x, shape=None):
def _get_strides_for_order_k(x, dtype, shape=None):
"""
Calculate strides when order='K' for empty_like, ones_like, zeros_like,
and full_like where `shape` is ``None`` or len(shape) == x.ndim.

"""
stride_and_index = sorted([(abs(s), -i) for i, s in enumerate(x.strides)])
strides = [0] * x.ndim
stride = 1
stride = dpnp.dtype(dtype).itemsize
for _, i in stride_and_index:
strides[-i] = stride
stride *= shape[-i] if shape else x.shape[-i]
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
elif 0 < offset < m:
out_shape = a_shape[:-2] + (min(n, m - offset),)
out_strides = a_straides[:-2] + (st_n + st_m,)
out_offset = st_m * offset
out_offset = st_m // a.itemsize * offset
else:
out_shape = a_shape[:-2] + (0,)
out_strides = a_straides[:-2] + (1,)
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _define_contig_flag(x):
"""

flag = False
x_strides = x.strides
x_strides = dpnp.get_usm_ndarray(x).strides
x_shape = x.shape
if x.ndim < 2:
return True, True, True
Expand Down
14 changes: 10 additions & 4 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,13 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
)
result = a
else:
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
if (
out is not None
and out.strides == tuple(out_strides)
and not ti._array_overlap(a_usm, dpnp.get_usm_ndarray(out))
and out_usm.strides == tuple(out_strides)
and not ti._array_overlap(a_usm, out_usm)
):
res_usm = dpnp.get_usm_ndarray(out)
res_usm = out_usm
result = out
else:
# Result array that is used in oneMKL must have the exact same
Expand All @@ -223,6 +224,10 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
if a.dtype == dpnp.complex64
else dpnp.float64
)
# cast to expected strides format
out_strides = tuple(
el * dpnp.dtype(out_dtype).itemsize for el in out_strides
)
result = dpnp_array(
out_shape,
dtype=out_dtype,
Expand Down Expand Up @@ -419,7 +424,8 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
if cufft_wa: # pragma: no cover
a = dpnp.moveaxis(a, -1, -2)

a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
strides = dpnp.get_usm_ndarray(a).strides
a_strides = _standardize_strides_to_nonzero(strides, a.shape)
dsc, out_strides = _commit_descriptor(
a, forward, in_place, c2c, a_strides, index, batch_fft
)
Expand Down
14 changes: 7 additions & 7 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _batched_inv(a, res_type):
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_stride = n
a_stride = a_h.strides[0]
a_stride = a_h.strides[0] // a_h.itemsize

# Call the LAPACK extension function _getrf_batch
# to perform LU decomposition of a batch of general matrices
Expand Down Expand Up @@ -298,7 +298,7 @@ def _batched_lu_factor(a, res_type):
dev_info_h = [0] * batch_size

ipiv_stride = n
a_stride = a_h.strides[0]
a_stride = a_h.strides[0] // a_h.itemsize

# Call the LAPACK extension function _getrf_batch
# to perform LU decomposition of a batch of general matrices
Expand Down Expand Up @@ -471,8 +471,8 @@ def _batched_qr(a, mode="reduced"):
dtype=res_type,
)

a_stride = a_t.strides[0]
tau_stride = tau_h.strides[0]
a_stride = a_t.strides[0] // a_t.itemsize
tau_stride = tau_h.strides[0] // tau_h.itemsize

# Call the LAPACK extension function _geqrf_batch to compute
# the QR factorization of a general m x n matrix.
Expand Down Expand Up @@ -535,8 +535,8 @@ def _batched_qr(a, mode="reduced"):
)
_manager.add_event_pair(ht_ev, copy_ev)

q_stride = q.strides[0]
tau_stride = tau_h.strides[0]
q_stride = q.strides[0] // q.itemsize
tau_stride = tau_h.strides[0] // tau_h.itemsize

# Get LAPACK function (_orgqr_batch for real or _ungqf_batch for complex
# data types) for QR factorization
Expand Down Expand Up @@ -1818,7 +1818,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):
)
_manager.add_event_pair(ht_ev, copy_ev)

a_stride = a_h.strides[0]
a_stride = a_h.strides[0] // a_h.itemsize

# Call the LAPACK extension function _potrf_batch
# to computes the Cholesky decomposition of a batch of
Expand Down
3 changes: 2 additions & 1 deletion dpnp/scipy/linalg/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

"""

# pylint: disable=duplicate-code
# pylint: disable=no-name-in-module
# pylint: disable=protected-access

Expand Down Expand Up @@ -144,7 +145,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
dev_info_h = [0] * batch_size

ipiv_stride = k
a_stride = a_h.strides[-1]
a_stride = a_h.strides[-1] // a_h.itemsize

# Call the LAPACK extension function _getrf_batch
# to perform LU decomposition of a batch of general matrices
Expand Down
4 changes: 2 additions & 2 deletions dpnp/tests/test_arraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,12 +861,12 @@ def test_full_order(order1, order2):
def test_full_strides():
a = numpy.full((3, 3), numpy.arange(3, dtype="i4"))
ia = dpnp.full((3, 3), dpnp.arange(3, dtype="i4"))
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
assert ia.strides == a.strides
assert_array_equal(ia, a)

a = numpy.full((3, 3), numpy.arange(6, dtype="i4")[::2])
ia = dpnp.full((3, 3), dpnp.arange(6, dtype="i4")[::2])
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
assert ia.strides == a.strides
assert_array_equal(ia, a)


Expand Down
8 changes: 4 additions & 4 deletions dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def test_attributes(self):
assert_equal(self.three.shape, (10, 3, 2))
self.three.shape = (2, 5, 6)

assert_equal(self.one.strides, (self.one.itemsize / self.one.itemsize,))
num = self.two.itemsize / self.two.itemsize
assert_equal(self.one.strides, (self.one.itemsize,))
num = self.two.itemsize
assert_equal(self.two.strides, (5 * num, num))
num = self.three.itemsize / self.three.itemsize
num = self.three.itemsize
assert_equal(self.three.strides, (30 * num, 6 * num, num))

assert_equal(self.one.ndim, 1)
Expand Down Expand Up @@ -290,7 +290,7 @@ def test_flags_strides(dtype, order, strides):
(4, 4), dtype=dtype, order=order, strides=strides
)
a = numpy.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=strides)
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
assert usm_array.flags == ia.flags
assert a.flags.c_contiguous == ia.flags.c_contiguous
assert a.flags.f_contiguous == ia.flags.f_contiguous
Expand Down
6 changes: 2 additions & 4 deletions dpnp/tests/third_party/cupy/core_tests/test_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import dpnp as cupy
from dpnp.tests.helper import (
has_support_aspect64,
is_win_platform,
numpy_version,
)
from dpnp.tests.third_party.cupy import testing

Expand Down Expand Up @@ -67,10 +65,10 @@ def test_copy_orders(self, order):
a = cupy.empty((2, 3, 4))
b = cupy.copy(a, order)

a_cpu = numpy.empty((2, 3, 4))
a_cpu = numpy.empty((2, 3, 4), dtype=a.dtype)
b_cpu = numpy.copy(a_cpu, order)

assert b.strides == tuple(x / b_cpu.itemsize for x in b_cpu.strides)
assert b.strides == b_cpu.strides


@pytest.mark.skip("`ElementwiseKernel` isn't supported")
Expand Down
Loading
Loading