Skip to content

Multimodal Fusion — Cross-Sensor Spike Train Merging

Fuse spike trains from multiple sensor modalities (vision/DVS, audio/cochlea, IMU) into a unified representation. Handles different time resolutions, firing rates, and channel counts.

Fusion Modes

| Mode | Description | Output Channels |
| --- | --- | --- |
| `concatenate` | Stack channels from all modalities | sum(n_channels) |
| `sum` | Element-wise OR (any-modality spike), pad smaller modalities | max(n_channels) |
| `attention` | Learned cross-modal weighting per modality | sum(n_channels) |

All modes include automatic timebase resampling (bin mapping from modality dt to output dt) and rate normalization (scale so max rate maps to 1.0).

Components

  • ModalityConfig — Configuration for one sensor modality.
| Field | Type | Meaning |
| --- | --- | --- |
| `name` | str | Modality identifier (e.g., "dvs", "audio") |
| `n_channels` | int | Channel count |
| `dt_us` | float | Time bin width in microseconds |
| `max_rate_hz` | float | Maximum expected firing rate (default 1000) |
  • MultiModalFusion — Main fusion engine.
| Parameter | Default | Meaning |
| --- | --- | --- |
| `modalities` | (required) | List of ModalityConfig |
| `output_dt_us` | 1000.0 | Output time bin width (common timebase) |
| `mode` | "concatenate" | Fusion mode |

Usage

from sc_neurocore.fusion.multimodal import ModalityConfig, MultiModalFusion
import numpy as np

# Define sensor modalities (dt_us is each sensor's native time-bin width)
dvs = ModalityConfig("dvs", n_channels=128, dt_us=100.0)
audio = ModalityConfig("audio", n_channels=64, dt_us=500.0)

# Create fuser on a common 100 us output timebase, concatenating channels
fuser = MultiModalFusion([dvs, audio], output_dt_us=100.0, mode="concatenate")

# Fuse spike trains: each matrix is (n_bins, n_channels) binary spikes
spikes = {
    "dvs": np.random.randint(0, 2, (100, 128)),
    "audio": np.random.randint(0, 2, (100, 64)),
}
fused = fuser.fuse(spikes, duration_us=10000.0)
print(f"Output shape: {fused.shape}")  # (100, 192): 128 + 64 channels concatenated

# Missing modality → its channels are zero-filled, output width unchanged
fused_partial = fuser.fuse({"dvs": spikes["dvs"]}, duration_us=10000.0)

See Tutorial 49: Multimodal Fusion.

sc_neurocore.fusion

Multimodal spike train fusion for cross-sensor SNN processing.

MultiModalFusion

Fuse spike trains from multiple sensor modalities.

Parameters

modalities : list of ModalityConfig
    Sensor modality definitions.
output_dt_us : float
    Output time bin width in microseconds (common timebase).
mode : str
    Fusion mode: 'concatenate', 'sum', or 'attention'.

Source code in src/sc_neurocore/fusion/multimodal.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class MultiModalFusion:
    """Fuse spike trains from multiple sensor modalities.

    Each modality is resampled onto a common output timebase, peak-normalized,
    and then combined according to the configured fusion mode.

    Parameters
    ----------
    modalities : list of ModalityConfig
        Sensor modality definitions.
    output_dt_us : float
        Output time bin width in microseconds (common timebase).
    mode : str
        Fusion mode: 'concatenate', 'sum', or 'attention'.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported fusion modes.
    """

    def __init__(
        self,
        modalities: list[ModalityConfig],
        output_dt_us: float = 1000.0,
        mode: str = "concatenate",
    ):
        self.modalities = modalities
        self.output_dt_us = output_dt_us
        self.mode = mode

        if mode == "concatenate":
            # Output stacks every modality's channels side by side.
            self.n_output = sum(m.n_channels for m in modalities)
        elif mode == "sum":
            # Output is as wide as the widest modality; narrower ones get padded.
            self.n_output = max(m.n_channels for m in modalities)
        elif mode == "attention":
            self.n_output = sum(m.n_channels for m in modalities)
            # Start from uniform cross-modal weights; callers may adapt them.
            n_mod = len(modalities)
            self.attention_weights = np.ones(n_mod) / n_mod
        else:
            raise ValueError(f"Unknown mode '{mode}'")

    @staticmethod
    def _resample(spikes: np.ndarray, n_output_bins: int, n_channels: int) -> np.ndarray:
        """Map a modality spike matrix onto `n_output_bins` time bins.

        Each output bin takes the element-wise max over the input bins it
        covers, so a spike anywhere inside the covered window is preserved.
        """
        n_bins_in = spikes.shape[0]
        if n_bins_in == n_output_bins:
            return spikes.astype(np.float64)
        out = np.zeros((n_output_bins, n_channels), dtype=np.float64)
        # Caller guarantees n_output_bins >= 1, so no division guard is needed.
        ratio = n_bins_in / n_output_bins
        for t_out in range(n_output_bins):
            t_in_start = int(t_out * ratio)
            t_in_end = min(int((t_out + 1) * ratio), n_bins_in)
            if t_in_start < t_in_end:
                out[t_out] = spikes[t_in_start:t_in_end].max(axis=0)
        return out

    @staticmethod
    def _normalize(rates: np.ndarray) -> np.ndarray:
        """Scale so the peak value maps to 1.0; all-zero input is returned as-is."""
        peak = rates.max()
        return rates / peak if peak > 0 else rates

    def fuse(self, spike_trains: dict[str, np.ndarray], duration_us: float) -> np.ndarray:
        """Fuse spike trains from all modalities into a unified output.

        Parameters
        ----------
        spike_trains : dict mapping modality name to spike matrix
            Each matrix has shape (n_bins_modality, n_channels_modality).
        duration_us : float
            Total duration in microseconds.

        Returns
        -------
        ndarray of shape (n_output_bins, n_output_channels)

        Raises
        ------
        ValueError
            If a provided spike matrix's channel count does not match its
            modality configuration.
        """
        n_output_bins = max(1, int(np.ceil(duration_us / self.output_dt_us)))

        resampled = []
        for mod in self.modalities:
            if mod.name not in spike_trains:
                # Missing modality contributes silence (all-zero channels).
                resampled.append(np.zeros((n_output_bins, mod.n_channels), dtype=np.float64))
                continue

            spikes = spike_trains[mod.name]
            if spikes.shape[1] != mod.n_channels:
                # Fail fast instead of silently producing a mis-shaped output.
                raise ValueError(
                    f"Modality '{mod.name}' expected {mod.n_channels} channels, "
                    f"got {spikes.shape[1]}"
                )

            resampled.append(
                self._normalize(self._resample(spikes, n_output_bins, mod.n_channels))
            )

        if self.mode == "concatenate":
            return np.concatenate(resampled, axis=1)

        if self.mode == "sum":
            # Pad narrower modalities to the common width, then saturate at 1.0.
            max_ch = self.n_output
            padded = []
            for r in resampled:
                if r.shape[1] < max_ch:
                    pad = np.zeros((r.shape[0], max_ch - r.shape[1]))
                    padded.append(np.concatenate([r, pad], axis=1))
                else:
                    padded.append(r[:, :max_ch])
            return np.clip(sum(padded), 0, 1)

        if self.mode == "attention":
            # Per-modality scalar weighting, then channel concatenation.
            weighted = [r * w for r, w in zip(resampled, self.attention_weights)]
            return np.concatenate(weighted, axis=1)

        raise ValueError(f"Unknown mode '{self.mode}'")  # pragma: no cover

fuse(spike_trains, duration_us)

Fuse spike trains from all modalities into a unified output.

Parameters

spike_trains : dict mapping modality name to spike matrix
    Each matrix has shape (n_bins_modality, n_channels_modality).
duration_us : float
    Total duration in microseconds.

Returns

ndarray of shape (n_output_bins, n_output_channels)

Source code in src/sc_neurocore/fusion/multimodal.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def fuse(self, spike_trains: dict[str, np.ndarray], duration_us: float) -> np.ndarray:
    """Fuse spike trains from all modalities into a unified output.

    Parameters
    ----------
    spike_trains : dict mapping modality name to spike matrix
        Each matrix has shape (n_bins_modality, n_channels_modality).
    duration_us : float
        Total duration in microseconds.

    Returns
    -------
    ndarray of shape (n_output_bins, n_output_channels)
    """
    out_bins = max(1, int(np.ceil(duration_us / self.output_dt_us)))

    streams = []
    for cfg in self.modalities:
        if cfg.name not in spike_trains:
            # Absent modality: contribute an all-zero placeholder of its width.
            streams.append(np.zeros((out_bins, cfg.n_channels), dtype=np.float64))
            continue

        train = spike_trains[cfg.name]
        in_bins = train.shape[0]

        if in_bins == out_bins:
            mapped = train.astype(np.float64)
        else:
            # Map input bins onto the common timebase; each output bin takes
            # the element-wise max over the input bins it covers.
            mapped = np.zeros((out_bins, cfg.n_channels), dtype=np.float64)
            step = in_bins / max(out_bins, 1)
            for b in range(out_bins):
                lo = int(b * step)
                hi = min(int((b + 1) * step), in_bins)
                if lo < hi:
                    mapped[b] = train[lo:hi].max(axis=0)

        # Peak normalization: scale so the maximum value becomes 1.0;
        # an all-zero stream is kept unchanged.
        peak = mapped.max()
        streams.append(mapped / peak if peak > 0 else mapped)

    if self.mode == "concatenate":
        return np.concatenate(streams, axis=1)

    if self.mode == "sum":
        # Widen every stream to the common width, then sum and saturate at 1.
        width = self.n_output
        widened = []
        for s in streams:
            deficit = width - s.shape[1]
            if deficit > 0:
                widened.append(np.concatenate([s, np.zeros((s.shape[0], deficit))], axis=1))
            else:
                widened.append(s[:, :width])
        return np.clip(sum(widened), 0, 1)

    if self.mode == "attention":
        # Scale each modality by its learned weight before concatenating.
        scaled = [s * w for s, w in zip(streams, self.attention_weights)]
        return np.concatenate(scaled, axis=1)

    raise ValueError(f"Unknown mode '{self.mode}'")  # pragma: no cover