Vladyslav Moroshan
Apply ruff formatting
96e1a32
import numpy as np
import torch
def generate_spikes(
size: int,
spikes_type: str = "choose_randomly",
spike_intervals: int | None = None,
n_spikes: int | None = None,
to_keep_rate: float = 0.4,
):
spikes = np.zeros(size)
if size < 120:
build_up_points = 1
elif size < 250:
build_up_points = np.random.choice([2, 1], p=[0.3, 0.7])
else:
build_up_points = np.random.choice([3, 2, 1], p=[0.15, 0.45, 0.4])
spike_duration = build_up_points * 2
if spikes_type == "choose_randomly":
spikes_type = np.random.choice(["regular", "patchy", "random"], p=[0.4, 0.5, 0.1])
if spikes_type == "patchy" and size < 64:
spikes_type = "regular"
if spikes_type in ["regular", "patchy"]:
if spike_intervals is None:
upper_bound = np.ceil(spike_duration / 0.05) ## at least 1 spike every 24 periods (120 if 5 spike duration)
lower_bound = np.ceil(spike_duration / 0.15) ## at most 3 spikes every 24 periods
spike_intervals = np.random.randint(lower_bound, upper_bound)
n_spikes = np.ceil(size / spike_intervals)
spike_intervals = np.arange(spike_intervals, size, spike_intervals)
if spikes_type == "patchy":
patch_size = np.random.randint(2, max(n_spikes * 0.7, 3))
to_keep = np.random.randint(np.ceil(patch_size * to_keep_rate), patch_size)
else:
n_spikes = (
n_spikes if n_spikes is not None else np.random.randint(4, min(max(size // (spike_duration * 3), 6), 20))
)
spike_intervals = np.sort(np.random.choice(np.arange(spike_duration, size), size=n_spikes, replace=False))
constant_build_rate = False
if spikes_type in ["regular", "patchy"]:
random_ = np.random.random()
constant_build_rate = True
patch_count = 0
spike_intervals -= 1
for interval in spike_intervals:
interval = np.round(interval).astype(int)
if spikes_type == "patchy":
if patch_count >= patch_size:
patch_count = 0
if patch_count < to_keep:
patch_count += 1
else:
patch_count += 1
continue
if not constant_build_rate:
random_ = np.random.random()
build_up_rate = np.random.uniform(0.5, 2) if random_ < 0.7 else np.random.uniform(2.5, 5)
spike_start = interval - build_up_points + 1
for i in range(build_up_points):
if 0 <= spike_start + i < len(spikes):
spikes[spike_start + i] = build_up_rate * (i + 1)
for i in range(1, build_up_points):
if (interval + i) < len(spikes):
spikes[interval + i] = spikes[interval - i]
# randomly make it positive or negative
spikes += 1
spikes = spikes * np.random.choice([1, -1], 1, p=[0.7, 0.3])
return torch.Tensor(spikes)
def generate_peak_spikes(ts_size, peak_period, spikes_type="regular"):
    """Build a spike series of length ``ts_size`` spaced by ``peak_period``.

    Thin wrapper around ``generate_spikes``: ``ts_size`` is forwarded as
    ``size``, ``peak_period`` as ``spike_intervals`` and ``spikes_type``
    unchanged; every other argument keeps its default.
    """
    spike_series = generate_spikes(
        size=ts_size,
        spikes_type=spikes_type,
        spike_intervals=peak_period,
    )
    return spike_series