import json
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.optimize.loss import mse_loss, sparsity_loss
from flamo.optimize.trainer import Trainer
from flamo.processor import dsp, system

from src.config import BaseConfig
from src.reverb import BaseFDN


def _parse_delay_lengths(rows):
    """Extract integer delay lengths from the first column of Dataframe rows.

    Empty rows and blank cells (``None`` or empty/whitespace-only strings) are
    skipped, so unfilled rows of the table do not abort parsing. Cells may be
    ints, floats or numeric strings (e.g. ``"593"`` or ``"593.0"``) as long as
    they represent whole numbers.

    Args:
        rows: Rows as delivered by ``gr.Dataframe(type="array")`` -- a list of
            lists; only the first cell of each row is read.

    Returns:
        A list of ints, one per non-blank row, in table order.

    Raises:
        ValueError: If a non-blank cell is not numeric or not a whole number.
    """
    delays = []
    for row in rows or []:
        if not row:
            continue
        cell = row[0]
        # Unfilled Dataframe cells arrive as None or "" -- skip them.
        if cell is None or (isinstance(cell, str) and not cell.strip()):
            continue
        value = float(cell)
        # is_integer() is False for 593.5, inf and nan alike.
        if not value.is_integer():
            raise ValueError(f"Delay length must be an integer, got {cell!r}")
        delays.append(int(value))
    return delays


def _to_json_compatible(value):
    """Recursively convert tensors / numpy values into JSON-serializable objects.

    torch tensors and numpy arrays/scalars become (nested) Python lists or
    scalars via ``tolist()``; dicts, lists and tuples are converted element-wise;
    anything else is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy().tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_json_compatible(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json_compatible(item) for item in value]
    return value


def _save_impulse_response(ir, sample_rate, audio_path):
    """Peak-normalize the impulse response tensor ``ir`` and write it as audio.

    Singleton dimensions are squeezed away before writing. The signal is scaled
    so that its largest absolute sample is 1, unless it is all zeros.

    Returns:
        ``audio_path``, the file written by ``soundfile.write``.
    """
    ir_audio = ir.detach().cpu().numpy()
    if ir_audio.ndim > 1:
        ir_audio = ir_audio.squeeze()
    # Normalize to [-1, 1] to avoid clipping when written to file.
    peak = np.abs(ir_audio).max()
    if peak > 0:
        ir_audio = ir_audio / peak
    # Writing the file ourselves avoids Gradio's own numpy->audio conversion.
    sf.write(audio_path, ir_audio, sample_rate)
    return audio_path


def process_fdn(N, delay_lengths, learning_rate, sparsity_weight, max_epochs):
    """
    Optimize a feedback delay network (FDN) and export its results.

    Args:
        N: Number of delay lines (positive integer)
        delay_lengths: Dataframe rows (list of lists) whose first column holds
            the N integer delay lengths
        learning_rate: Learning rate for optimization
        sparsity_weight: Weight for sparsity loss
        max_epochs: Maximum number of training epochs

    Returns:
        A 3-tuple ``(message, params_path, audio_path)``: a status message,
        the path of the JSON file with the estimated parameters, and the path
        of the WAV file with the impulse response. On any error the message
        describes the problem and both paths are ``None`` (the Gradio
        Interface always expects three outputs).
    """
    print(f"Number of delay lines (N): {N}")
    print(f"Delay lengths: {delay_lengths}")
    print(f"Type of delay_lengths: {type(delay_lengths)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Sparsity weight: {sparsity_weight}")
    print(f"Max epochs: {max_epochs}")

    try:
        delays = _parse_delay_lengths(delay_lengths)
        if not delays:
            return "Error: No delay lengths provided", None, None

        # gr.Number may deliver None (cleared field) or a float.
        if N is None or N < 1 or int(N) != N:
            return "Error: N must be a positive integer", None, None
        N = int(N)

        # Validate that we have N delay values
        if len(delays) != N:
            return f"Error: Expected {N} delay lengths, but got {len(delays)}", None, None
        if any(d <= 0 for d in delays):
            return f"Error: Delay lengths must be positive integers, got {delays}", None, None

        result = f"Successfully configured FDN with:\n"
        result += f"- Number of delay lines: {N}\n"
        result += f"- Delay lengths: {delays}"

        # Create the config with FDN parameters
        config = BaseConfig.create_with_fdn_params(
            N=N,
            delay_lengths=delays
        )

        # Initialize BaseFDN with proper parameters
        model = BaseFDN(
            config=config.fdn_config,
            nfft=config.nfft,
            alias_decay_db=config.fdn_config.alias_decay_db,
            device=config.device,
            requires_grad=True,
            delay_lengths=delays,
            output_layer="freq_mag",
        )

        dataset = DatasetColorless(
            input_shape=(1, config.nfft, 1),
            target_shape=(1, config.nfft // 2 + 1, 1),
            expand=config.fdn_optim_config.dataset_length,
            device=config.device,
        )
        train_loader, valid_loader = load_dataset(dataset, batch_size=config.fdn_optim_config.batch_size)

        # Initialize training process
        trainer = Trainer(
            model.shell,
            max_epochs=max_epochs,
            lr=learning_rate,
            device=config.device,
            log=False
        )
        trainer.register_criterion(mse_loss(nfft=config.nfft, device=config.device), 1)
        trainer.register_criterion(sparsity_loss(), sparsity_weight, requires_model=True)

        ## ---------------- TRAIN ---------------- ##
        print("Starting training...")
        trainer.train(train_loader, valid_loader)

        # Assumes get_params() returns a mapping of name -> value; TODO confirm.
        est_param = model.get_params()
        param_dict = {key: _to_json_compatible(value) for key, value in est_param.items()}

        output_path = "estimated_parameters.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(param_dict, f, indent=2)

        ir = model.shell.get_time_response()
        # Fall back to 48 kHz when the config has no 'fs' attribute.
        sample_rate = getattr(config, 'fs', 48000)
        audio_path = _save_impulse_response(ir, sample_rate, "impulse_response.wav")

        return result, output_path, audio_path

    except Exception as e:
        # Top-level UI boundary: report the error in the output box and
        # log the full traceback to the console.
        print(f"Error: {e}")
        traceback.print_exc()
        return f"Error processing inputs: {str(e)}", None, None


demo = gr.Interface(
    fn=process_fdn,
    inputs=[
        gr.Number(label="N (Number of delay lines)", value=4, precision=0),
        gr.Dataframe(
            headers=["Delay Length"],
            type="array",
            col_count=(1, "fixed"),
            row_count=(4, "dynamic"),
            label="Delay Lengths (N integer values)"
        ),
        gr.Number(label="Learning Rate", value=0.01, minimum=0.0001, maximum=1.0, step=0.0001),
        gr.Number(label="Sparsity Loss Weight", value=1.0, minimum=0.0, maximum=10.0, step=0.1),
        gr.Number(label="Max Epochs", value=20, precision=0, minimum=1, maximum=100)
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.File(label="Estimated Parameters (JSON)"),
        gr.Audio(label="Impulse Response", type="numpy")
    ],
    title="Feedback Delay Network Optimization",
    description="Configure your homogeneous feedback delay network by specifying N (number of delay lines) and their corresponding delay lengths. Submit the values to run optimization and obtain estimated parameters and playback the resulting impulse response."
)

demo.launch(debug=True)