Gloria Dal Santo
Add more input parameters
28f49cd
import gradio as gr
import json
import torch
import numpy as np
import soundfile as sf
from src.config import BaseConfig
from src.reverb import BaseFDN
from flamo.optimize.trainer import Trainer
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.processor import dsp, system
from flamo.optimize.loss import mse_loss, sparsity_loss
def _parse_delay_lengths(delay_lengths):
    """Extract positive integer delay lengths from the first column of each row.

    Rows that are empty, or whose first cell is None / a blank string (as an
    untouched Dataframe cell may be), are skipped rather than treated as errors.

    Args:
        delay_lengths: Iterable of rows (e.g. the list of lists produced by a
            ``gr.Dataframe(type="array")``); only ``row[0]`` is read.

    Returns:
        List of ints, in row order.

    Raises:
        ValueError: if a non-blank cell is not a positive integer value.
    """
    delays = []
    for row in delay_lengths:
        if row is None or len(row) == 0:
            continue
        cell = row[0]
        if cell is None or (isinstance(cell, str) and not cell.strip()):
            continue
        message = f"delay lengths must be positive integers, got {cell!r}"
        try:
            value = float(cell)
        except (TypeError, ValueError):
            raise ValueError(message) from None
        # Reject fractional values (no silent truncation), NaN/inf, and <= 0.
        if not value.is_integer() or value <= 0:
            raise ValueError(message)
        delays.append(int(value))
    return delays


def _to_jsonable(value):
    """Recursively convert tensors / numpy data into JSON-serializable Python objects."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy().tolist()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # numpy scalars (e.g. np.float32) are not accepted by json.dump
        return value.item()
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


def _normalize_ir(ir):
    """Return the impulse response as a squeezed numpy array peak-normalized to [-1, 1]."""
    audio = ir.cpu().detach().numpy()
    if audio.ndim > 1:
        # NOTE(review): assumes squeezing leaves a 1-D signal (or frames x channels,
        # as soundfile.write expects) — confirm against get_time_response's output shape.
        audio = audio.squeeze()
    peak = np.abs(audio).max() if audio.size else 0.0
    if peak > 0:
        audio = audio / peak
    return audio


def process_fdn(N, delay_lengths, learning_rate, sparsity_weight, max_epochs):
    """
    Optimize a feedback delay network (FDN) for the given delay lines.

    Builds a BaseFDN from the supplied delays, trains it with flamo's Trainer
    on a DatasetColorless using an MSE loss (weight 1) plus a sparsity loss
    (weight ``sparsity_weight``), then writes the estimated parameters to
    ``estimated_parameters.json`` and the peak-normalized impulse response to
    ``impulse_response.wav`` in the current working directory.

    NOTE(review): the output file names are fixed, so concurrent requests
    overwrite each other's files.

    Args:
        N: Number of delay lines (integer, >= 1)
        delay_lengths: Rows whose first column holds the N integer delay lengths
        learning_rate: Learning rate for optimization
        sparsity_weight: Weight for sparsity loss
        max_epochs: Maximum number of training epochs

    Returns:
        A 3-tuple ``(message, json_path, wav_path)``. On any error the message
        starts with "Error" and both paths are None, so the tuple always matches
        the three Gradio outputs.
    """
    print(f"Number of delay lines (N): {N}")
    print(f"Delay lengths: {delay_lengths}")
    print(f"Type of delay_lengths: {type(delay_lengths)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Sparsity weight: {sparsity_weight}")
    print(f"Max epochs: {max_epochs}")
    try:
        # `is None` / len() instead of truthiness, which is ambiguous for arrays.
        if delay_lengths is None or len(delay_lengths) == 0:
            return "Error: No delay lengths provided", None, None

        # gr.Number may deliver floats; downstream code expects integers.
        N = int(N)
        max_epochs = int(max_epochs)
        if N < 1:
            return "Error: N must be a positive integer", None, None

        try:
            delays = _parse_delay_lengths(delay_lengths)
        except ValueError as err:
            return f"Error: {err}", None, None
        if len(delays) != N:
            return f"Error: Expected {N} delay lengths, but got {len(delays)}", None, None

        result = "\n".join([
            "Successfully configured FDN with:",
            f"- Number of delay lines: {N}",
            f"- Delay lengths: {delays}",
        ])

        # Create the config with FDN parameters
        config = BaseConfig.create_with_fdn_params(N=N, delay_lengths=delays)
        model = BaseFDN(
            config=config.fdn_config,
            nfft=config.nfft,
            alias_decay_db=config.fdn_config.alias_decay_db,
            device=config.device,
            requires_grad=True,
            delay_lengths=delays,
            output_layer="freq_mag",
        )
        dataset = DatasetColorless(
            input_shape=(1, config.nfft, 1),
            target_shape=(1, config.nfft // 2 + 1, 1),
            expand=config.fdn_optim_config.dataset_length,
            device=config.device,
        )
        train_loader, valid_loader = load_dataset(
            dataset, batch_size=config.fdn_optim_config.batch_size
        )

        trainer = Trainer(
            model.shell,
            max_epochs=max_epochs,
            lr=learning_rate,
            device=config.device,
            log=False,
        )
        trainer.register_criterion(mse_loss(nfft=config.nfft, device=config.device), 1)
        trainer.register_criterion(sparsity_loss(), sparsity_weight, requires_model=True)

        ## ---------------- TRAIN ---------------- ##
        print("Starting training...")
        trainer.train(train_loader, valid_loader)

        # Save estimated parameters as JSON (tensors/arrays converted to lists).
        est_param = model.get_params()
        param_dict = {key: _to_jsonable(value) for key, value in est_param.items()}
        output_path = "estimated_parameters.json"
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(param_dict, f, indent=2)

        # Save the impulse response with soundfile (avoids Gradio's conversion issues).
        ir_audio = _normalize_ir(model.shell.get_time_response())
        sample_rate = getattr(config, "fs", 48000)
        audio_path = "impulse_response.wav"
        sf.write(audio_path, ir_audio, sample_rate)
        return result, output_path, audio_path
    except Exception as e:
        # Top-level UI boundary: report the error to the user instead of crashing.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing inputs: {str(e)}", None, None
# Gradio UI: the five inputs map positionally onto process_fdn's parameters
# (N, delay_lengths, learning_rate, sparsity_weight, max_epochs), and the three
# outputs onto its (message, json_path, wav_path) return tuple.
demo = gr.Interface(
    fn=process_fdn,
    inputs=[
        gr.Number(label="N (Number of delay lines)", value=4, precision=0),
        gr.Dataframe(
            headers=["Delay Length"],
            type="array",
            col_count=(1, "fixed"),
            row_count=(4, "dynamic"),
            label="Delay Lengths (N integer values)",
        ),
        gr.Number(label="Learning Rate", value=0.01, minimum=0.0001, maximum=1.0, step=0.0001),
        gr.Number(label="Sparsity Loss Weight", value=1.0, minimum=0.0, maximum=10.0, step=0.1),
        gr.Number(label="Max Epochs", value=20, precision=0, minimum=1, maximum=100),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.File(label="Estimated Parameters (JSON)"),
        # process_fdn returns the path of a .wav file, so declare a filepath-typed output.
        gr.Audio(label="Impulse Response", type="filepath"),
    ],
    title="Feedback Delay Network Optimization",
    description="Configure your homogeneous feedback delay network by specifying N (number of delay lines) and their corresponding delay lengths. Submit the values to run optimization and obtain estimated parameters and playback the resulting impulse response.",
)
demo.launch(debug=True)