expand_dims
- array_api_extra.expand_dims(a, /, *, axis=(0,), xp=None)
Expand the shape of an array.
Deprecated since version 0.11.0: expand_dims() is deprecated and will be removed in v1.0.0. array_api.expand_dims() with support for a tuple of ints in axis exists in the standard as of v2025.12.

Insert (a) new axis/axes that will appear at the position(s) specified by axis in the expanded array shape.
- Parameters:
a (object) – Array to have its shape expanded.
axis (int | tuple[int, ...]) – Position(s) in the expanded axes where the new axis (or axes) is/are placed. If multiple positions are provided, they must be unique. A positive index and a negative index that refer to the same position count as duplicates and also result in an error. Default: (0,).
xp (ModuleType | None) – The standard-compatible namespace for a. Default: infer.
- Returns:
a with an expanded shape.
- Return type:
Examples
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1, 2])
>>> x.shape
(2,)
The following is equivalent to x[xp.newaxis, :] or x[xp.newaxis]:
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
>>> y
Array([[1, 2]], dtype=array_api_strict.int64)
>>> y.shape
(1, 2)
The following is equivalent to x[:, xp.newaxis]:
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
>>> y
Array([[1],
       [2]], dtype=array_api_strict.int64)
>>> y.shape
(2, 1)
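The indexing equivalences above can be generalised with a small pure-Python sketch. The helper name newaxis_index is hypothetical and not part of array_api_extra; it assumes the positions in axis are valid and unique, and it relies on newaxis being an alias for None, as defined in the array API standard.

```python
def newaxis_index(ndim, axis):
    """Build an index tuple that inserts new axes at `axis` (hypothetical helper).

    `ndim` is the number of dimensions of the input array. No validation is
    done here: positions are assumed to be in bounds and unique.
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = ndim + len(axis)
    # Map negative positions onto their non-negative equivalents.
    positions = {ax % out_ndim for ax in axis}
    # None plays the role of xp.newaxis; slice(None) is a plain ":".
    return tuple(None if i in positions else slice(None) for i in range(out_ndim))


print(newaxis_index(1, 0))       # (None, slice(None, None, None))
print(newaxis_index(1, 1))       # (slice(None, None, None), None)
print(newaxis_index(1, (2, 0)))  # (None, slice(None, None, None), None)
```

Indexing a one-dimensional x with the first result corresponds to x[xp.newaxis, :], and with the second to x[:, xp.newaxis].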
axis may also be a tuple:
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
>>> y
Array([[[1, 2]]], dtype=array_api_strict.int64)
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
>>> y
Array([[[1],
        [2]]], dtype=array_api_strict.int64)
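To make the rules for axis concrete, the following is a minimal pure-Python sketch of how the output shape could be derived from the input shape. The helper expanded_shape is hypothetical and is not the library's implementation; the exception types it raises are this sketch's own choice and may differ from what expand_dims actually raises.

```python
def expanded_shape(shape, axis=(0,)):
    """Return the shape after inserting size-1 axes at `axis` (hypothetical helper)."""
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = len(shape) + len(axis)

    normalized = []
    for ax in axis:
        # Positions refer to the expanded shape, so bounds use out_ndim.
        if not -out_ndim <= ax < out_ndim:
            raise IndexError(f"axis {ax} is out of bounds for {out_ndim} dimensions")
        normalized.append(ax % out_ndim)

    # A positive and a negative index naming the same position collide here.
    if len(set(normalized)) != len(normalized):
        raise ValueError("duplicate positions in axis")

    # Fill the new positions with 1 and the rest with the original sizes, in order.
    remaining = iter(shape)
    return tuple(1 if i in normalized else next(remaining) for i in range(out_ndim))


print(expanded_shape((2,), axis=(0, 1)))  # (1, 1, 2)
print(expanded_shape((2,), axis=(2, 0)))  # (1, 2, 1)
```

In this sketch, axis=(0, -3) on a one-dimensional input is rejected: the expanded array has three dimensions, so -3 names the same position as 0.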