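"""Gradio demo for feedback delay network (FDN) optimization with flamo.

The user supplies the number of delay lines and their delay lengths; the app
optimizes the FDN toward a colorless (flat-magnitude) target and returns the
estimated parameters as JSON together with the rendered impulse response.
"""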
import gradio as gr
import json
import numpy as np
import soundfile as sf

from src.config import BaseConfig
from src.reverb import BaseFDN
from flamo.optimize.trainer import Trainer
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.optimize.loss import mse_loss, sparsity_loss

def process_fdn(N, delay_lengths, learning_rate, sparsity_weight, max_epochs):
    """
    Configure and optimize a feedback delay network from user inputs.

    Args:
        N: Number of delay lines (integer).
        delay_lengths: Dataframe rows, each holding one integer delay length (in samples).
        learning_rate: Learning rate for the optimizer.
        sparsity_weight: Weight of the sparsity loss term.
        max_epochs: Maximum number of training epochs.

    Returns:
        A tuple of (status message, path to the estimated-parameter JSON file,
        path to the impulse-response WAV file).
    """
| print(f"Number of delay lines (N): {N}") | |
| print(f"Delay lengths: {delay_lengths}") | |
| print(f"Type of delay_lengths: {type(delay_lengths)}") | |
| print(f"Learning rate: {learning_rate}") | |
| print(f"Sparsity weight: {sparsity_weight}") | |
| print(f"Max epochs: {max_epochs}") | |
    try:
        # Extract delay length values from the dataframe (a list of rows),
        # skipping the empty rows/cells a dynamic dataframe may contain
        if delay_lengths is not None and len(delay_lengths) > 0:
            delays = [
                int(row[0])
                for row in delay_lengths
                if row and len(row) > 0 and str(row[0]).strip() != ""
            ]
            # Gradio may deliver N as a float; normalize before comparing
            N = int(N)
            # Validate that we received exactly N delay values
            if len(delays) != N:
                return f"Error: Expected {N} delay lengths, but got {len(delays)}", None, None
            result = "Successfully configured FDN with:\n"
            result += f"- Number of delay lines: {N}\n"
            result += f"- Delay lengths: {delays}"

            # Create the config with the FDN parameters
            config = BaseConfig.create_with_fdn_params(
                N=N,
                delay_lengths=delays,
            )
            # Initialize the BaseFDN model
            model = BaseFDN(
                config=config.fdn_config,
                nfft=config.nfft,
                alias_decay_db=config.fdn_config.alias_decay_db,
                device=config.device,
                requires_grad=True,
                delay_lengths=delays,
                output_layer="freq_mag",
            )
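            # model.shell wraps the FDN core between the input and output layers
            # that the Trainer drives end to end; output_layer="freq_mag"
            # presumably selects a magnitude-spectrum output.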
            dataset = DatasetColorless(
                input_shape=(1, config.nfft, 1),
                target_shape=(1, config.nfft // 2 + 1, 1),  # one-sided spectrum bins
                expand=config.fdn_optim_config.dataset_length,
                device=config.device,
            )
            train_loader, valid_loader = load_dataset(dataset, batch_size=config.fdn_optim_config.batch_size)
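            # DatasetColorless pairs each input with a flat magnitude target,
            # so the optimization drives the FDN toward a colorless response.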
            # Initialize the training process
            trainer = Trainer(
                model.shell,
                max_epochs=max_epochs,
                lr=learning_rate,
                device=config.device,
                log=False,
            )
            trainer.register_criterion(mse_loss(nfft=config.nfft, device=config.device), 1)
            trainer.register_criterion(sparsity_loss(), sparsity_weight, requires_model=True)
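            # The total loss is a weighted sum: an MSE term against the flat
            # magnitude target (weight 1) plus a sparsity term evaluated on the
            # model itself (hence requires_model=True), scaled by sparsity_weight.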
            ## ---------------- TRAIN ---------------- ##
            print("Starting training...")
            trainer.train(train_loader, valid_loader)
            est_param = model.get_params()

            # Convert learned parameters (assumed to be a dict of tensors or
            # plain values) into JSON-serializable types
            param_dict = {}
            for key, value in est_param.items():
                if hasattr(value, "cpu"):
                    # Move tensors to CPU and convert to nested lists
                    param_dict[key] = value.cpu().detach().numpy().tolist()
                else:
                    param_dict[key] = value

            # Save the estimated parameters to a JSON file
            output_path = "estimated_parameters.json"
            with open(output_path, "w") as f:
                json.dump(param_dict, f, indent=2)
            # Render the impulse response for playback
            ir = model.shell.get_time_response()
            ir_audio = ir.cpu().detach().numpy()

            # Ensure a 1-D array
            if ir_audio.ndim > 1:
                ir_audio = ir_audio.squeeze()

            # Peak-normalize to [-1, 1] to avoid clipping on export
            max_val = np.abs(ir_audio).max()
            if max_val > 0:
                ir_audio = ir_audio / max_val

            # Use the configured sample rate, falling back to 48 kHz
            sample_rate = getattr(config, "fs", 48000)

            # Write the WAV with soundfile (avoids Gradio's conversion issues)
            audio_path = "impulse_response.wav"
            sf.write(audio_path, ir_audio, sample_rate)
            return result, output_path, audio_path
        else:
            return "Error: No delay lengths provided", None, None
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing inputs: {e}", None, None

demo = gr.Interface(
    fn=process_fdn,
    inputs=[
        gr.Number(label="N (Number of delay lines)", value=4, precision=0),
        gr.Dataframe(
            headers=["Delay Length"],
            type="array",
            col_count=(1, "fixed"),
            row_count=(4, "dynamic"),
            label="Delay Lengths (N integer values)",
        ),
        gr.Number(label="Learning Rate", value=0.01, minimum=0.0001, maximum=1.0, step=0.0001),
        gr.Number(label="Sparsity Loss Weight", value=1.0, minimum=0.0, maximum=10.0, step=0.1),
        gr.Number(label="Max Epochs", value=20, precision=0, minimum=1, maximum=100),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.File(label="Estimated Parameters (JSON)"),
        # The function returns a WAV path, so declare the output as a filepath
        gr.Audio(label="Impulse Response", type="filepath"),
    ],
    title="Feedback Delay Network Optimization",
    description=(
        "Configure a homogeneous feedback delay network by specifying N (the "
        "number of delay lines) and their delay lengths. Submit the values to "
        "run the optimization, download the estimated parameters, and play "
        "back the resulting impulse response."
    ),
)
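# debug=True keeps the process attached and prints full tracebacks to the
# console, which is convenient while iterating on the Space.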
demo.launch(debug=True)