graphomotor.plot.spiral_plots

Spiral visualization functions for quality control and data inspection.

This module provides plotting functions for visualizing spiral drawing trajectories from CSV files. The functions support both single spiral plotting and batch processing for quality control purposes.

Plot Types

  • Single spiral plots: Visualize individual spiral drawings with optional reference spiral overlay and color-coded line segments.
  • Batch spiral plots: Visualize multiple spiral drawings organized by metadata with configurable subplot density and arrangement.
  1"""Spiral visualization functions for quality control and data inspection.
  2
  3This module provides plotting functions for visualizing spiral drawing trajectories from
  4CSV files. The functions support both single spiral plotting and batch processing for
  5quality control purposes.
  6
  7Plot Types
  8----------
  9- **Single spiral plots**: Visualize individual spiral drawings with optional
 10  reference spiral overlay and color-coded line segments.
 11- **Batch spiral plots**: Visualize multiple spiral drawings organized by metadata
 12  with configurable subplot density and arrangement.
 13"""
 14
 15import pathlib
 16
 17import matplotlib
 18import numpy as np
 19from matplotlib import pyplot as plt
 20
 21from graphomotor.core import config, models
 22from graphomotor.io import reader
 23from graphomotor.utils import center_spiral, plotting
 24
 25matplotlib.use("agg")  # prevent interactive matplotlib
 26logger = config.get_logger()
 27
 28
 29def _plot_spiral(
 30    ax: plt.Axes,
 31    spiral: models.Spiral,
 32    centered_ref: np.ndarray,
 33    include_reference: bool = False,
 34    color_segments: bool = False,
 35    is_batch: bool = False,
 36) -> None:
 37    """Plot a spiral on a given matplotlib Axes.
 38
 39    Args:
 40        ax: Matplotlib Axes to plot on.
 41        spiral: Spiral object containing drawing data and metadata.
 42        centered_ref: Pre-computed centered reference spiral coordinates.
 43        include_reference: If True, overlays the reference spiral for comparison.
 44        color_segments: If True, colors each line segment with distinct colors.
 45        is_batch: If True, formats for batch mode (no legend, no axis labels,
 46            smaller font).
 47    """
 48    centered_spiral = center_spiral.center_spiral(spiral)
 49    x_coords = centered_spiral.data["x"].values
 50    y_coords = centered_spiral.data["y"].values
 51    line_numbers = centered_spiral.data["line_number"].values
 52
 53    if color_segments:
 54        unique_line_numbers = np.unique(line_numbers)
 55        color_indices = np.linspace(0, 1, len(unique_line_numbers))
 56
 57        if len(unique_line_numbers) <= 10:
 58            colors = plt.get_cmap("tab10")(color_indices)
 59        elif len(unique_line_numbers) <= 20:
 60            colors = plt.get_cmap("tab20")(color_indices)
 61        else:
 62            colors = plt.get_cmap("viridis")(color_indices)
 63
 64        color_map = dict(zip(unique_line_numbers, colors))
 65
 66        for i in range(len(x_coords) - 1):
 67            current_line = line_numbers[i]
 68            label = (
 69                f"Line {int(current_line)}"
 70                if i == 0 or line_numbers[i - 1] != current_line
 71                else None
 72            )
 73            ax.plot(
 74                [x_coords[i], x_coords[i + 1]],
 75                [y_coords[i], y_coords[i + 1]],
 76                color=color_map[current_line],
 77                linewidth=2 if not is_batch else 1,
 78                alpha=0.8,
 79                label=label,
 80            )
 81    else:
 82        ax.plot(
 83            x_coords,
 84            y_coords,
 85            "tab:blue",
 86            linewidth=1.5,
 87            alpha=0.8,
 88            label="Drawn spiral",
 89        )
 90
 91    ax.plot(
 92        centered_ref[0, 0],
 93        centered_ref[0, 1],
 94        "go",
 95        markersize=12 if not is_batch else 6,
 96        label="Start" if not is_batch else None,
 97        zorder=5,
 98        alpha=0.3,
 99    )
100    ax.plot(
101        centered_ref[-1, 0],
102        centered_ref[-1, 1],
103        "ro",
104        markersize=12 if not is_batch else 6,
105        label="End" if not is_batch else None,
106        zorder=5,
107        alpha=0.3,
108    )
109
110    if include_reference:
111        ax.plot(
112            centered_ref[:, 0],
113            centered_ref[:, 1],
114            "k--",
115            linewidth=6 if not is_batch else 3,
116            alpha=0.15,
117            label="Reference spiral" if not is_batch else None,
118        )
119
120    participant_id, task, hand, start_time = plotting.extract_spiral_metadata(spiral)
121
122    ax.set_title(
123        label=f"ID: {participant_id}\n{task} - {hand}\n{start_time}",
124        fontsize=8 if is_batch else 14,
125    )
126
127    ax.set_aspect("equal")
128
129    if not is_batch:
130        ax.set_xlabel("X Position (pixels)", fontsize=12)
131        ax.set_ylabel("Y Position (pixels)", fontsize=12)
132        ax.legend()
133        ax.grid(True, alpha=0.3)
134    else:
135        ax.set_xticks([])
136        ax.set_yticks([])
137
138
def plot_single_spiral(
    data: str | pathlib.Path | models.Spiral,
    output_path: str | pathlib.Path | None = None,
    include_reference: bool = False,
    color_segments: bool = False,
    spiral_config: config.SpiralConfig | None = None,
) -> plt.Figure:
    """Plot a single spiral drawing with optional reference spiral and color coding.

    This function creates a visualization of an individual spiral drawing trajectory.
    The spiral can be colored with distinct segments for better visualization of
    drawing progression, and optionally overlaid with a reference spiral for
    comparison.

    Args:
        data: Path to CSV file containing spiral data, or a loaded Spiral object.
        output_path: Optional directory where the figure will be saved. If None
            (or any other falsy value, such as an empty string), the function
            only returns the figure without saving. When saving, the filename
            passed to ``plotting.save_figure`` is
            ``spiral_{participant_id}_{task}_{hand}``.
        include_reference: If True, overlays the reference spiral for comparison.
        color_segments: If True, colors each line segment with distinct colors.
        spiral_config: Configuration for reference spiral generation. If None,
            uses default configuration.

    Returns:
        The matplotlib Figure object.

    Raises:
        ValueError: If the input data is invalid or cannot be loaded.
    """
    logger.debug("Starting single spiral plot generation")

    if isinstance(data, (str, pathlib.Path)):
        try:
            spiral = reader.load_spiral(data)
        except Exception as e:
            # Any loader failure is reported as ValueError, the documented error
            # type, while the original exception is kept as the cause.
            error_msg = f"Failed to load spiral data from {data}: {e}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e
    elif isinstance(data, models.Spiral):
        spiral = data
    else:
        error_msg = f"Invalid data type: {type(data)}. Expected str, Path, or Spiral."
        logger.error(error_msg)
        raise ValueError(error_msg)

    centered_ref = plotting.get_reference_spiral(spiral_config)

    fig, ax = plt.subplots(figsize=(10, 10))

    _plot_spiral(
        ax=ax,
        spiral=spiral,
        centered_ref=centered_ref,
        include_reference=include_reference,
        color_segments=color_segments,
        is_batch=False,
    )

    # Lay out this figure explicitly rather than pyplot's implicit "current
    # figure", so the call cannot act on a different figure.
    fig.tight_layout()

    if output_path:
        participant_id, task, hand, _ = plotting.extract_spiral_metadata(spiral)
        filename = f"spiral_{participant_id}_{task}_{hand}"
        plotting.save_figure(figure=fig, output_path=output_path, filename=filename)

    return fig
205
206
def plot_batch_spirals(
    data: str | pathlib.Path,
    output_path: str | pathlib.Path | None = None,
    include_reference: bool = False,
    color_segments: bool = False,
    spiral_config: config.SpiralConfig | None = None,
) -> plt.Figure:
    """Plot every spiral found in a directory as a structured grid of subplots.

    Spirals are loaded from the directory and arranged with one row per
    participant/hand combination and one column per task. Grid cells without a
    matching spiral show a "No Data" placeholder. If exactly one spiral is
    loaded, the call is delegated to ``plot_single_spiral`` instead.

    Args:
        data: Path to directory containing spiral CSV files.
        output_path: Optional directory where the figure will be saved. If None
            (or any other falsy value), the figure is only returned. The grid
            figure is saved under the filename ``batch_spirals``.
        include_reference: If True, overlays reference spirals for comparison.
        color_segments: If True, colors each line segment with distinct colors.
        spiral_config: Configuration for reference spiral generation. If None,
            uses default configuration.

    Returns:
        The matplotlib Figure object containing all spiral plots.

    Raises:
        ValueError: If the input path does not exist or is not a directory, or
            if ``plotting.load_spirals_from_directory`` raises a ValueError.
    """
    logger.debug("Starting batch spiral plot generation")

    data = pathlib.Path(data)
    if not (data.exists() and data.is_dir()):
        error_msg = f"Input path does not exist or is not a directory: {data}"
        logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        spirals = plotting.load_spirals_from_directory(data)
    except ValueError as load_error:
        # Log for context, then let the loader's own ValueError propagate.
        logger.error(f"Failed to load any spirals from directory: {load_error}")
        raise

    # A one-spiral "batch" gets the richer single-plot rendering.
    if len(spirals) == 1:
        return plot_single_spiral(
            data=spirals[0],
            output_path=output_path,
            include_reference=include_reference,
            color_segments=color_segments,
            spiral_config=spiral_config,
        )

    # The reference spiral is computed once and shared by every cell.
    centered_ref = plotting.get_reference_spiral(spiral_config)

    indexed = plotting.index_spirals_by_metadata(spirals)
    spiral_grid, participant_hand_combos, sorted_tasks = indexed

    fig, axes = plotting.create_grid_layout(
        len(participant_hand_combos), len(sorted_tasks)
    )

    for row_idx, (participant, hand) in enumerate(participant_hand_combos):
        for col_idx, task in enumerate(sorted_tasks):
            cell = axes[row_idx, col_idx]
            key = (participant, hand, task)

            if key not in spiral_grid:
                # No recording for this participant/hand/task: draw a placeholder.
                cell.text(
                    0.5,
                    0.5,
                    "No Data",
                    ha="center",
                    va="center",
                    transform=cell.transAxes,
                    fontsize=12,
                    alpha=0.5,
                )
                cell.set_xticks([])
                cell.set_yticks([])
                cell.set_aspect("equal")
                cell.set_title("No Data", fontsize=8)
            else:
                _plot_spiral(
                    ax=cell,
                    spiral=spiral_grid[key],
                    centered_ref=centered_ref,
                    include_reference=include_reference,
                    color_segments=color_segments,
                    is_batch=True,
                )

            plotting.add_grid_labels(cell, key, row_idx, col_idx)

    if output_path:
        plotting.save_figure(
            figure=fig, output_path=output_path, filename="batch_spirals"
        )

    return fig
logger = <Logger graphomotor (WARNING)>
def plot_single_spiral( data: str | pathlib._local.Path | graphomotor.core.models.Spiral, output_path: str | pathlib._local.Path | None = None, include_reference: bool = False, color_segments: bool = False, spiral_config: graphomotor.core.config.SpiralConfig | None = None) -> matplotlib.figure.Figure:
def plot_single_spiral(
    data: str | pathlib.Path | models.Spiral,
    output_path: str | pathlib.Path | None = None,
    include_reference: bool = False,
    color_segments: bool = False,
    spiral_config: config.SpiralConfig | None = None,
) -> plt.Figure:
    """Plot a single spiral drawing with optional reference spiral and color coding.

    This function creates a visualization of an individual spiral drawing trajectory.
    The spiral can be colored with distinct segments for better visualization of
    drawing progression, and optionally overlaid with a reference spiral for
    comparison.

    Args:
        data: Path to CSV file containing spiral data, or a loaded Spiral object.
        output_path: Optional directory where the figure will be saved. If None
            (or any other falsy value, such as an empty string), the function
            only returns the figure without saving. When saving, the filename
            passed to ``plotting.save_figure`` is
            ``spiral_{participant_id}_{task}_{hand}``.
        include_reference: If True, overlays the reference spiral for comparison.
        color_segments: If True, colors each line segment with distinct colors.
        spiral_config: Configuration for reference spiral generation. If None,
            uses default configuration.

    Returns:
        The matplotlib Figure object.

    Raises:
        ValueError: If the input data is invalid or cannot be loaded.
    """
    logger.debug("Starting single spiral plot generation")

    if isinstance(data, (str, pathlib.Path)):
        try:
            spiral = reader.load_spiral(data)
        except Exception as e:
            # Any loader failure is reported as ValueError, the documented error
            # type, while the original exception is kept as the cause.
            error_msg = f"Failed to load spiral data from {data}: {e}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e
    elif isinstance(data, models.Spiral):
        spiral = data
    else:
        error_msg = f"Invalid data type: {type(data)}. Expected str, Path, or Spiral."
        logger.error(error_msg)
        raise ValueError(error_msg)

    centered_ref = plotting.get_reference_spiral(spiral_config)

    fig, ax = plt.subplots(figsize=(10, 10))

    _plot_spiral(
        ax=ax,
        spiral=spiral,
        centered_ref=centered_ref,
        include_reference=include_reference,
        color_segments=color_segments,
        is_batch=False,
    )

    # Lay out this figure explicitly rather than pyplot's implicit "current
    # figure", so the call cannot act on a different figure.
    fig.tight_layout()

    if output_path:
        participant_id, task, hand, _ = plotting.extract_spiral_metadata(spiral)
        filename = f"spiral_{participant_id}_{task}_{hand}"
        plotting.save_figure(figure=fig, output_path=output_path, filename=filename)

    return fig

Plot a single spiral drawing with optional reference spiral and color coding.

This function creates a visualization of an individual spiral drawing trajectory. The spiral can be colored with distinct segments for better visualization of drawing progression, and optionally overlaid with a reference spiral for comparison.

Arguments:
  • data: Path to CSV file containing spiral data, or a loaded Spiral object.
  • output_path: Optional directory where the figure will be saved. If None, the function only returns the figure without saving.
  • include_reference: If True, overlays the reference spiral for comparison.
  • color_segments: If True, colors each line segment with distinct colors.
  • spiral_config: Configuration for reference spiral generation. If None, uses default configuration.
Returns:

The matplotlib Figure object.

Raises:
  • ValueError: If the input data is invalid or cannot be loaded.
def plot_batch_spirals( data: str | pathlib._local.Path, output_path: str | pathlib._local.Path | None = None, include_reference: bool = False, color_segments: bool = False, spiral_config: graphomotor.core.config.SpiralConfig | None = None) -> matplotlib.figure.Figure:
def plot_batch_spirals(
    data: str | pathlib.Path,
    output_path: str | pathlib.Path | None = None,
    include_reference: bool = False,
    color_segments: bool = False,
    spiral_config: config.SpiralConfig | None = None,
) -> plt.Figure:
    """Plot every spiral found in a directory as a structured grid of subplots.

    Spirals are loaded from the directory and arranged with one row per
    participant/hand combination and one column per task. Grid cells without a
    matching spiral show a "No Data" placeholder. If exactly one spiral is
    loaded, the call is delegated to ``plot_single_spiral`` instead.

    Args:
        data: Path to directory containing spiral CSV files.
        output_path: Optional directory where the figure will be saved. If None
            (or any other falsy value), the figure is only returned. The grid
            figure is saved under the filename ``batch_spirals``.
        include_reference: If True, overlays reference spirals for comparison.
        color_segments: If True, colors each line segment with distinct colors.
        spiral_config: Configuration for reference spiral generation. If None,
            uses default configuration.

    Returns:
        The matplotlib Figure object containing all spiral plots.

    Raises:
        ValueError: If the input path does not exist or is not a directory, or
            if ``plotting.load_spirals_from_directory`` raises a ValueError.
    """
    logger.debug("Starting batch spiral plot generation")

    data = pathlib.Path(data)
    if not (data.exists() and data.is_dir()):
        error_msg = f"Input path does not exist or is not a directory: {data}"
        logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        spirals = plotting.load_spirals_from_directory(data)
    except ValueError as load_error:
        # Log for context, then let the loader's own ValueError propagate.
        logger.error(f"Failed to load any spirals from directory: {load_error}")
        raise

    # A one-spiral "batch" gets the richer single-plot rendering.
    if len(spirals) == 1:
        return plot_single_spiral(
            data=spirals[0],
            output_path=output_path,
            include_reference=include_reference,
            color_segments=color_segments,
            spiral_config=spiral_config,
        )

    # The reference spiral is computed once and shared by every cell.
    centered_ref = plotting.get_reference_spiral(spiral_config)

    indexed = plotting.index_spirals_by_metadata(spirals)
    spiral_grid, participant_hand_combos, sorted_tasks = indexed

    fig, axes = plotting.create_grid_layout(
        len(participant_hand_combos), len(sorted_tasks)
    )

    for row_idx, (participant, hand) in enumerate(participant_hand_combos):
        for col_idx, task in enumerate(sorted_tasks):
            cell = axes[row_idx, col_idx]
            key = (participant, hand, task)

            if key not in spiral_grid:
                # No recording for this participant/hand/task: draw a placeholder.
                cell.text(
                    0.5,
                    0.5,
                    "No Data",
                    ha="center",
                    va="center",
                    transform=cell.transAxes,
                    fontsize=12,
                    alpha=0.5,
                )
                cell.set_xticks([])
                cell.set_yticks([])
                cell.set_aspect("equal")
                cell.set_title("No Data", fontsize=8)
            else:
                _plot_spiral(
                    ax=cell,
                    spiral=spiral_grid[key],
                    centered_ref=centered_ref,
                    include_reference=include_reference,
                    color_segments=color_segments,
                    is_batch=True,
                )

            plotting.add_grid_labels(cell, key, row_idx, col_idx)

    if output_path:
        plotting.save_figure(
            figure=fig, output_path=output_path, filename="batch_spirals"
        )

    return fig

Plot multiple spirals in a batch using a structured grid layout.

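This function processes multiple spiral CSV files from a directory and creates a structured grid of spiral visualizations with rows for participant/hand combinations and columns for tasks. If only one spiral is loaded from the directory, the call is delegated to plot_single_spiral and a single-spiral figure is returned instead.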

Arguments:
  • data: Path to directory containing spiral CSV files.
  • output_path: Optional directory where the figure will be saved. If None, the function only returns the figure without saving.
  • include_reference: If True, overlays reference spirals for comparison.
  • color_segments: If True, colors each line segment with distinct colors.
  • spiral_config: Configuration for reference spiral generation. If None, uses default configuration.
Returns:

The matplotlib Figure object containing all spiral plots.

Raises:
  • ValueError: If the input path does not exist or is not a directory, or if loading spirals from the directory raises a ValueError.