|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
def generate_spikes( |
|
|
size: int, |
|
|
spikes_type: str = "choose_randomly", |
|
|
spike_intervals: int | None = None, |
|
|
n_spikes: int | None = None, |
|
|
to_keep_rate: float = 0.4, |
|
|
): |
|
|
spikes = np.zeros(size) |
|
|
if size < 120: |
|
|
build_up_points = 1 |
|
|
elif size < 250: |
|
|
build_up_points = np.random.choice([2, 1], p=[0.3, 0.7]) |
|
|
else: |
|
|
build_up_points = np.random.choice([3, 2, 1], p=[0.15, 0.45, 0.4]) |
|
|
|
|
|
spike_duration = build_up_points * 2 |
|
|
|
|
|
if spikes_type == "choose_randomly": |
|
|
spikes_type = np.random.choice(["regular", "patchy", "random"], p=[0.4, 0.5, 0.1]) |
|
|
|
|
|
if spikes_type == "patchy" and size < 64: |
|
|
spikes_type = "regular" |
|
|
|
|
|
if spikes_type in ["regular", "patchy"]: |
|
|
if spike_intervals is None: |
|
|
upper_bound = np.ceil(spike_duration / 0.05) |
|
|
lower_bound = np.ceil(spike_duration / 0.15) |
|
|
spike_intervals = np.random.randint(lower_bound, upper_bound) |
|
|
n_spikes = np.ceil(size / spike_intervals) |
|
|
spike_intervals = np.arange(spike_intervals, size, spike_intervals) |
|
|
if spikes_type == "patchy": |
|
|
patch_size = np.random.randint(2, max(n_spikes * 0.7, 3)) |
|
|
to_keep = np.random.randint(np.ceil(patch_size * to_keep_rate), patch_size) |
|
|
else: |
|
|
n_spikes = ( |
|
|
n_spikes if n_spikes is not None else np.random.randint(4, min(max(size // (spike_duration * 3), 6), 20)) |
|
|
) |
|
|
spike_intervals = np.sort(np.random.choice(np.arange(spike_duration, size), size=n_spikes, replace=False)) |
|
|
|
|
|
constant_build_rate = False |
|
|
if spikes_type in ["regular", "patchy"]: |
|
|
random_ = np.random.random() |
|
|
constant_build_rate = True |
|
|
|
|
|
patch_count = 0 |
|
|
spike_intervals -= 1 |
|
|
for interval in spike_intervals: |
|
|
interval = np.round(interval).astype(int) |
|
|
if spikes_type == "patchy": |
|
|
if patch_count >= patch_size: |
|
|
patch_count = 0 |
|
|
if patch_count < to_keep: |
|
|
patch_count += 1 |
|
|
else: |
|
|
patch_count += 1 |
|
|
continue |
|
|
if not constant_build_rate: |
|
|
random_ = np.random.random() |
|
|
build_up_rate = np.random.uniform(0.5, 2) if random_ < 0.7 else np.random.uniform(2.5, 5) |
|
|
|
|
|
spike_start = interval - build_up_points + 1 |
|
|
for i in range(build_up_points): |
|
|
if 0 <= spike_start + i < len(spikes): |
|
|
spikes[spike_start + i] = build_up_rate * (i + 1) |
|
|
|
|
|
for i in range(1, build_up_points): |
|
|
if (interval + i) < len(spikes): |
|
|
spikes[interval + i] = spikes[interval - i] |
|
|
|
|
|
|
|
|
spikes += 1 |
|
|
spikes = spikes * np.random.choice([1, -1], 1, p=[0.7, 0.3]) |
|
|
|
|
|
return torch.Tensor(spikes) |
|
|
|
|
|
|
|
|
def generate_peak_spikes(ts_size, peak_period, spikes_type="regular"):
    """Build a spike series of length ``ts_size`` spaced by ``peak_period``.

    Thin wrapper that forwards to ``generate_spikes``, passing
    ``peak_period`` as its ``spike_intervals`` argument and leaving every
    other option at its default.
    """
    return generate_spikes(
        size=ts_size,
        spikes_type=spikes_type,
        spike_intervals=peak_period,
    )
|
|
|