Commit 340fca20 authored by Thomas Hartmann
Browse files

Merge branch 'add-whosmat-function' into 'master'

Add whosmat() to list variables without loading data

Closes #23

See merge request !51
parents 8098bdcd fd0010bf
Loading
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -30,6 +30,6 @@
"""Read `.mat` files disregarding of the underlying file version."""

from ._version import __version__
from .pymatreader import read_mat
from .pymatreader import read_mat, whosmat

__all__ = ['read_mat', '__version__']
__all__ = ['read_mat', 'whosmat', '__version__']
+37 −2
Original line number Diff line number Diff line
@@ -41,12 +41,12 @@ try:
except ImportError:
    from scipy.io.matlab.miobase import get_matfile_version as matfile_version

from .utils import _hdf5todict, _import_h5py, _parse_scipy_mat_dict
from .utils import _hdf5todict, _import_h5py, _parse_scipy_mat_dict, _whosmat_hdf5

if TYPE_CHECKING:
    from collections.abc import Iterable

__all__ = ['read_mat']
__all__ = ['read_mat', 'whosmat']


def read_mat(
@@ -106,3 +106,38 @@ def read_mat(
        h5py = _import_h5py()
        with h5py.File(filename, 'r') as hdf5_file:
            return _hdf5todict(hdf5_file, variable_names=variable_names, ignore_fields=ignore_fields)


def whosmat(
    filename: str | Path,
) -> list[tuple[str, tuple[int, ...], str]]:
    """Enumerate the variables of a MATLAB file without reading their data.

    Intended to work for all MATLAB file versions (v4 through v7.3). The file
    is first handed to :func:`scipy.io.whosmat`; when that path raises
    ``NotImplementedError`` (the v7.3 / HDF5 case), the file is reopened with
    h5py and its HDF5 metadata is inspected instead.

    Parameters
    ----------
    filename : str | Path
        Location of the ``.mat`` file.

    Returns
    -------
    list of (name, shape, class) tuples
        One tuple per variable: the variable name, its shape as a tuple of
        ints, and its MATLAB class as a string (e.g. ``'double'``,
        ``'char'``, ``'sparse'``).

    Raises
    ------
    OSError
        If ``filename`` does not point to an existing path.
    """
    mat_path = Path(filename)
    if not mat_path.exists():
        raise OSError(f'The file {filename} does not exist.')

    try:
        with mat_path.open('rb') as stream:
            # Probe the header first, then let scipy list the variables.
            matfile_version(stream)
            variables = scipy.io.whosmat(stream)
    except NotImplementedError:
        # The scipy path could not handle this file; fall back to the HDF5
        # reader (the binary stream above is already closed at this point).
        hdf5_backend = _import_h5py()
        with hdf5_backend.File(mat_path, 'r') as hdf5_handle:
            variables = _whosmat_hdf5(hdf5_handle)

    return variables
+47 −1
Original line number Diff line number Diff line
@@ -78,6 +78,52 @@ def _import_h5py() -> h5py:
    return h5py


def _whosmat_hdf5(
    hdf5_file: h5py.File,
) -> list[tuple[str, tuple[int, ...], str]]:
    """List variables in an HDF5-based MATLAB v7.3 file without loading data.

    Only the root level of the file is inspected. Entries whose name starts
    with ``'#'`` (e.g. ``'#refs#'``, ``'#subsystem#'``) are skipped: MATLAB
    variable names must start with a letter, so such entries are file-internal
    bookkeeping rather than user variables.

    Parameters
    ----------
    hdf5_file : h5py.File
        An open HDF5 file handle.

    Returns
    -------
    list of (name, shape, class) tuples
        Each entry describes one top-level MATLAB variable. The list is
        sorted by variable name. The class is taken from the ``MATLAB_class``
        attribute (``'unknown'`` when absent), or ``'sparse'`` for groups
        carrying a ``MATLAB_sparse`` attribute. Groups without that attribute
        (structs / cell arrays) are always reported with shape ``(1, 1)``.
    """
    h5py = _import_h5py()

    result: list[tuple[str, tuple[int, ...], str]] = []

    for key in hdf5_file:
        # '#'-prefixed root entries are MATLAB-internal, never user variables.
        if key.startswith('#'):
            continue

        obj = hdf5_file[key]

        # h5py hands back bytes for fixed-length string attributes but str for
        # variable-length ones, so accept both instead of assuming bytes.
        raw_class = obj.attrs.get('MATLAB_class', b'unknown')
        matlab_class = raw_class.decode() if isinstance(raw_class, bytes) else str(raw_class)

        if isinstance(obj, h5py.Dataset):
            if 'MATLAB_empty' in obj.attrs:
                # For arrays flagged as empty, the dataset's contents are read
                # as the dimensions themselves (a handful of integers).
                shape = tuple(int(dim) for dim in obj[()])
            else:
                # The HDF5 axis order is reversed to obtain the reported shape
                # (presumably because MATLAB stores arrays column-major).
                shape = tuple(reversed(obj.shape))
        elif isinstance(obj, h5py.Group):
            if 'MATLAB_sparse' in obj.attrs:
                matlab_class = 'sparse'
                # The attribute value is used as the row count; the column
                # count is derived from the length of the 'jc' dataset
                # (presumably the column-pointer array, one longer than the
                # number of columns).
                n_rows = int(obj.attrs['MATLAB_sparse'])
                jc = obj.get('jc')
                n_cols = len(jc) - 1 if jc is not None else 0
                shape = (n_rows, n_cols)
            else:
                shape = (1, 1)  # structs / cell arrays stored as groups
        else:
            # Neither a dataset nor a group: nothing to report.
            continue

        result.append((key, shape, matlab_class))

    result.sort(key=lambda entry: entry[0])
    return result


def _hdf5todict(
    hdf5_object: h5py.Dataset | h5py.Group,
    variable_names: Iterable | None = None,
@@ -272,7 +318,7 @@ def _parse_scipy_mat_dict(data: dict) -> dict:


def _check_for_scipy_mat_struct(
    data: dict | np.ndarray | spmatrix | MatlabOpaque
    data: dict | np.ndarray | spmatrix | MatlabOpaque,
) -> dict | np.ndarray | csc_array | list | None:
    """
    Check all entries of data for occurrences of scipy.io.matlab.mio5_params.mat_struct and convert them.
+60 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ import numpy as np
import pytest
from scipy import sparse

from pymatreader import read_mat
from pymatreader import read_mat, whosmat

from .helper_functions import _read_xml_data, _sanitize_dict, assertDeepAlmostEqual

@@ -339,3 +339,62 @@ def test_sparse_matrices(version):

            np.testing.assert_allclose(matrix.toarray(), csv_matrix, atol=1e-15)


def test_whosmat_v7():
    """Check the structure of whosmat output and two known entries of the v7 fixture."""
    listing = whosmat(Path(test_data_folder, testdata_v7_fname))
    assert isinstance(listing, list)
    # Every entry must be a (name, shape, class) triple.
    assert all(isinstance(item, tuple) and len(item) == 3 for item in listing)  # noqa: PLR2004

    info = {}
    for var_name, var_shape, var_class in listing:
        info[var_name] = (var_shape, var_class)

    assert 'a_matrix' in info
    assert info['a_matrix'] == ((100, 100), 'double')
    assert info['a_float'] == ((1, 1), 'double')


def test_whosmat_v73():
    """Check the structure of whosmat output and two known entries of the v7.3 (HDF5) fixture."""
    fixture_path = Path(test_data_folder, testdata_v73_fname)
    listing = whosmat(fixture_path)
    assert isinstance(listing, list)
    # Every entry must be a (name, shape, class) triple.
    assert all(isinstance(item, tuple) and len(item) == 3 for item in listing)  # noqa: PLR2004

    info = dict((var_name, (var_shape, var_class)) for var_name, var_shape, var_class in listing)  # noqa: C402
    assert 'a_matrix' in info
    assert info['a_matrix'] == ((100, 100), 'double')
    assert info['a_float'] == ((1, 1), 'double')


def test_whosmat_v7_v73_consistency():
    """Check that the v7 and v7.3 fixtures yield the same variable names and classes."""
    listing_v7 = whosmat(Path(test_data_folder, testdata_v7_fname))
    listing_v73 = whosmat(Path(test_data_folder, testdata_v73_fname))

    # Map each variable name to its reported class; shapes are not compared here.
    classes_v7 = {entry[0]: entry[2] for entry in listing_v7}
    classes_v73 = {entry[0]: entry[2] for entry in listing_v73}

    names_v7 = {entry[0] for entry in listing_v7}
    names_v73 = {entry[0] for entry in listing_v73}
    assert names_v7 == names_v73

    for name in names_v7:
        assert classes_v7[name] == classes_v73[name], f'class mismatch for {name}'


@pytest.mark.parametrize('version', ['4', '6', '7', '73'])
def test_whosmat_sparse(version):
    """Check that whosmat reports the expected shape and the 'sparse' class for each sparse fixture variable."""
    listing = whosmat(Path(test_data_folder, f'sparse_v{version}.mat'))
    found = {}
    for var_name, var_shape, var_class in listing:
        found[var_name] = (var_shape, var_class)

    # Expected (rows, cols) per variable; the same values are asserted for every file version.
    expected_shapes = {
        'A_square': (10, 10),
        'A_tall': (20, 10),
        'A_wide': (10, 20),
        'A_col': (10, 1),
        'A_row': (1, 10),
        'A_single': (1, 1),
    }
    for var_name, shape in expected_shapes.items():
        assert found[var_name] == (shape, 'sparse')


def test_whosmat_file_not_found():
    """Check that whosmat raises OSError when given the invalid fixture filename."""
    missing_path = Path(test_data_folder, invalid_fname)
    with pytest.raises(OSError):
        whosmat(missing_path)