Source code for kicad_sch_api.collections.wires

"""
Enhanced wire management with IndexRegistry integration.

Provides WireCollection using BaseCollection infrastructure with
endpoint indexing and wire geometry queries.
"""

import logging
import uuid as uuid_module
from typing import Any, Dict, List, Optional, Tuple, Union

from ..core.types import Point, Wire, WireType
from .base import BaseCollection, IndexSpec, ValidationLevel

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class WireCollection(BaseCollection[Wire]):
    """
    Wire collection with UUID indexing and geometry queries.

    Inherits from BaseCollection for UUID indexing and adds wire-specific
    functionality for endpoint queries and wire type filtering.

    Only the UUID is registered as an index (see ``_get_index_specs``);
    the endpoint, point, geometry and type queries are linear scans over
    the stored wires.

    Features:
        - Fast UUID lookup via IndexRegistry
        - Multi-point wire support
        - Endpoint-based queries
        - Horizontal/vertical wire detection
        - Lazy index rebuilding
        - Batch mode support
    """

    def __init__(
        self,
        wires: Optional[List[Wire]] = None,
        validation_level: ValidationLevel = ValidationLevel.NORMAL,
    ):
        """
        Initialize wire collection.

        Args:
            wires: Initial list of wires
            validation_level: Validation level for operations
        """
        super().__init__(validation_level=validation_level)

        # Add initial wires. The base-class add() is used on purpose: it
        # stores the given Wire objects as they are, whereas self.add()
        # builds a brand-new Wire from coordinates.
        if wires:
            with self.batch_mode():
                for wire in wires:
                    super().add(wire)

        logger.debug("WireCollection initialized with %s wires", len(self))

    # BaseCollection abstract method implementations

    def _get_item_uuid(self, item: Wire) -> str:
        """Extract UUID from wire."""
        return item.uuid

    def _create_item(self, **kwargs) -> Wire:
        """Create a new wire (not typically used directly)."""
        raise NotImplementedError("Use add() method to create wires")

    def _get_index_specs(self) -> List[IndexSpec]:
        """Get index specifications for wire collection."""
        return [
            IndexSpec(
                name="uuid",
                key_func=lambda w: w.uuid,
                unique=True,
                description="UUID index for fast lookups",
            ),
        ]

    # Internal helpers

    @staticmethod
    def _coerce_point(value: Union[Point, Tuple[float, float]]) -> Point:
        """Return ``value`` as a Point, converting an ``(x, y)`` tuple if needed.

        Anything that is not a tuple is returned unchanged.
        """
        if isinstance(value, tuple):
            return Point(value[0], value[1])
        return value

    # Wire-specific add method

    def add(
        self,
        start: Optional[Union[Point, Tuple[float, float]]] = None,
        end: Optional[Union[Point, Tuple[float, float]]] = None,
        points: Optional[List[Union[Point, Tuple[float, float]]]] = None,
        wire_type: WireType = WireType.WIRE,
        stroke_width: float = 0.0,
        uuid: Optional[str] = None,
    ) -> str:
        """
        Add a wire to the collection.

        A non-empty ``points`` list takes precedence over ``start``/``end``.

        Args:
            start: Start point (for simple wires)
            end: End point (for simple wires)
            points: List of points (for multi-point wires)
            wire_type: Wire type (wire or bus)
            stroke_width: Line width
            uuid: Optional UUID (auto-generated if not provided)

        Returns:
            UUID of the created wire

        Raises:
            ValueError: If neither start/end nor points are provided
            ValueError: If the wire would have fewer than 2 points
            ValueError: If UUID already exists
        """
        # Generate UUID if not provided
        if uuid is None:
            uuid = str(uuid_module.uuid4())
        else:
            # Check for duplicate; refresh the index first so the lookup
            # sees every wire currently stored.
            self._ensure_indexes_current()
            if self._index_registry.has_key("uuid", uuid):
                raise ValueError(f"Wire with UUID '{uuid}' already exists")

        # Convert points
        if points:
            # Multi-point wire
            wire_points = [self._coerce_point(p) for p in points]
        elif start is not None and end is not None:
            # Simple 2-point wire
            wire_points = [self._coerce_point(start), self._coerce_point(end)]
        else:
            raise ValueError("Must provide either start/end points or points list")

        # Validate wire has at least 2 points
        if len(wire_points) < 2:
            raise ValueError("Wire must have at least 2 points")

        # Create wire
        wire = Wire(
            uuid=uuid,
            points=wire_points,
            wire_type=wire_type,
            stroke_width=stroke_width,
        )

        # Add to collection
        super().add(wire)

        logger.debug("Added wire: %s points, UUID=%s", len(wire_points), uuid)
        return uuid

    # Endpoint-based queries

    def get_by_endpoint(
        self, point: Union[Point, Tuple[float, float]], tolerance: float = 0.01
    ) -> List[Wire]:
        """
        Find all wires with an endpoint near a given point.

        Only the first and last point of each wire are considered.

        Args:
            point: Point to search for
            tolerance: Distance tolerance for matching (inclusive)

        Returns:
            List of wires with endpoint near the point
        """
        target = self._coerce_point(point)

        matching_wires = []
        for wire in self._items:
            wire_points = wire.points
            # Wires handed to the constructor bypass the point-count check
            # in add(); skip any without points instead of raising
            # IndexError on the [0] / [-1] lookups below.
            if not wire_points:
                continue
            if (
                wire_points[0].distance_to(target) <= tolerance
                or wire_points[-1].distance_to(target) <= tolerance
            ):
                matching_wires.append(wire)

        return matching_wires

    def get_at_point(
        self, point: Union[Point, Tuple[float, float]], tolerance: float = 0.01
    ) -> List[Wire]:
        """
        Find all wires that have a stored point at or near a given point.

        Only the wire's own points are tested; positions lying between two
        consecutive points of a segment are not.

        Args:
            point: Point to search for
            tolerance: Distance tolerance for matching (inclusive)

        Returns:
            List of wires with at least one point near the search point
        """
        target = self._coerce_point(point)

        # any() stops at the first matching point of each wire.
        return [
            wire
            for wire in self._items
            if any(p.distance_to(target) <= tolerance for p in wire.points)
        ]

    # Wire geometry queries

    def get_horizontal(self, tolerance: float = 0.01) -> List[Wire]:
        """
        Get all horizontal wires (Y coordinates equal).

        Only simple 2-point wires are considered; multi-point wires are
        never reported.

        Args:
            tolerance: Maximum Y difference (exclusive) between the two
                points. Defaults to 0.01.

        Returns:
            List of horizontal wires
        """
        return [
            wire
            for wire in self._items
            if len(wire.points) == 2
            and abs(wire.points[0].y - wire.points[1].y) < tolerance
        ]

    def get_vertical(self, tolerance: float = 0.01) -> List[Wire]:
        """
        Get all vertical wires (X coordinates equal).

        Only simple 2-point wires are considered; multi-point wires are
        never reported.

        Args:
            tolerance: Maximum X difference (exclusive) between the two
                points. Defaults to 0.01.

        Returns:
            List of vertical wires
        """
        return [
            wire
            for wire in self._items
            if len(wire.points) == 2
            and abs(wire.points[0].x - wire.points[1].x) < tolerance
        ]

    def get_by_type(self, wire_type: WireType) -> List[Wire]:
        """
        Get all wires of a specific type.

        Args:
            wire_type: Wire type to filter by

        Returns:
            List of wires matching the type
        """
        return [w for w in self._items if w.wire_type == wire_type]

    # Statistics

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get wire collection statistics.

        Returns:
            The base collection statistics extended with ``total_wires``,
            ``total_segments``, ``wire_count``, ``bus_count``,
            ``horizontal_count``, ``vertical_count`` and
            ``avg_points_per_wire`` (0 for an empty collection).
        """
        wires = self._items
        total = len(wires)

        wire_stats = {
            "total_wires": total,
            "total_segments": sum(len(w.points) - 1 for w in wires),
            "wire_count": sum(1 for w in wires if w.wire_type == WireType.WIRE),
            "bus_count": sum(1 for w in wires if w.wire_type == WireType.BUS),
            "horizontal_count": len(self.get_horizontal()),
            "vertical_count": len(self.get_vertical()),
            # Guard the division so an empty collection reports 0.
            "avg_points_per_wire": (
                sum(len(w.points) for w in wires) / total if total else 0
            ),
        }

        base_stats = super().get_statistics()
        base_stats.update(wire_stats)
        return base_stats

    # Compatibility methods

    @property
    def modified(self) -> bool:
        """Check if collection has been modified (compatibility)."""
        return self.is_modified

    def mark_saved(self) -> None:
        """Mark collection as saved (reset modified flag)."""
        self.mark_clean()