import gradio as gr
import json
import torch
import numpy as np
import soundfile as sf
from src.config import BaseConfig
from src.reverb import BaseFDN
from flamo.optimize.trainer import Trainer
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.processor import dsp, system
from flamo.optimize.loss import mse_loss, sparsity_loss
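
# Gradio front end for optimizing a homogeneous feedback delay network (FDN):
# the user supplies N delay lines and their lengths, flamo's training loop
# estimates the FDN parameters, and the app returns the estimates as JSON
# together with the optimized impulse response rendered to a WAV file.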

def process_fdn(N, delay_lengths, learning_rate, sparsity_weight, max_epochs):
    """
    Process feedback delay network parameters.
    
    Args:
        N: Number of delay lines (integer)
        delay_lengths: Array of N integer values for delay lengths
        learning_rate: Learning rate for optimization
        sparsity_weight: Weight for sparsity loss
        max_epochs: Maximum number of training epochs
    
    Returns:
        A message confirming the inputs
    """
    print(f"Number of delay lines (N): {N}")
    print(f"Delay lengths: {delay_lengths}")
    print(f"Type of delay_lengths: {type(delay_lengths)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Sparsity weight: {sparsity_weight}")
    print(f"Max epochs: {max_epochs}")
    
    try:
        # Extract delay length values from the dataframe
        if delay_lengths and len(delay_lengths) > 0:
            # delay_lengths is a list of rows, extract the first column value from each row
            delays = [int(row[0]) for row in delay_lengths if row and len(row) > 0]
            
            # Validate that we have N delay values
            if len(delays) != N:
                return f"Error: Expected {N} delay lengths, but got {len(delays)}"
            
            result = f"Successfully configured FDN with:\n"
            result += f"- Number of delay lines: {N}\n"
            result += f"- Delay lengths: {delays}"

            # Create the config with FDN parameters
            config = BaseConfig.create_with_fdn_params(
                N=N,
                delay_lengths=delays
            )
            
            # Initialize BaseFDN with proper parameters
            model = BaseFDN(
                config=config.fdn_config,
                nfft=config.nfft,
                alias_decay_db=config.fdn_config.alias_decay_db,
                device=config.device,
                requires_grad=True,
                delay_lengths=delays,
                output_layer="freq_mag",
            )

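            # Build the dataset for flamo's colorless-FDN optimization:
            # frequency-domain inputs of length nfft and targets over the
            # nfft // 2 + 1 positive-frequency bins (shapes mirror those used above).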
            dataset = DatasetColorless(
                input_shape=(1, config.nfft, 1),
                target_shape=(1, config.nfft // 2 + 1, 1),
                expand=config.fdn_optim_config.dataset_length,
                device=config.device,
            )
            train_loader, valid_loader = load_dataset(dataset, batch_size=config.fdn_optim_config.batch_size)

            # Initialize training process
            trainer = Trainer(
                model.shell,
                max_epochs=max_epochs,
                lr=learning_rate,
                device=config.device,
                log=False
            )
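            # Register the training objectives: an MSE between the model output and
            # the dataset target (weight 1), plus a sparsity penalty weighted by the
            # user-supplied sparsity_weight; requires_model=True indicates the loss
            # needs access to the model itself, not just its output.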
            trainer.register_criterion(mse_loss(nfft=config.nfft, device=config.device), 1)
            trainer.register_criterion(sparsity_loss(), sparsity_weight, requires_model=True)

            ## ---------------- TRAIN ---------------- ##

            # Train the model
            print("Starting training...")
            trainer.train(train_loader, valid_loader)
            est_param = model.get_params()
            
            # Convert parameters to JSON format
            # Assuming est_param is a dict or can be converted to one
            param_dict = {}
            for key, value in est_param.items():
                # Convert tensors to lists for JSON serialization
                if hasattr(value, 'cpu'):
                    param_dict[key] = value.cpu().detach().numpy().tolist()
                else:
                    param_dict[key] = value
            
            # Save to JSON file
            output_path = "estimated_parameters.json"
            with open(output_path, 'w') as f:
                json.dump(param_dict, f, indent=2)

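            # Render the time-domain impulse response of the optimized FDN.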
            ir = model.shell.get_time_response()
            
            # Convert ir to audio format for Gradio
            ir_audio = ir.cpu().detach().numpy()
            
            # Ensure proper shape (1D array)
            if ir_audio.ndim > 1:
                ir_audio = ir_audio.squeeze()
            
            # Normalize to [-1, 1] range to prevent overflow
            max_val = np.abs(ir_audio).max()
            if max_val > 0:
                ir_audio = ir_audio / max_val
            
            # Get the sample rate from config
            sample_rate = getattr(config, 'fs', 48000)
            
            # Save audio to file using soundfile (avoids Gradio's conversion issues)
            audio_path = "impulse_response.wav"
            sf.write(audio_path, ir_audio, sample_rate)
            
            return result, output_path, audio_path
        else:
            return "Error: No delay lengths provided", None, None
            
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing inputs: {str(e)}", None, None

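# Wire the processing function to a simple Gradio interface: numeric fields for N,
# the learning rate, the sparsity weight and the epoch count, a one-column dataframe
# for the delay lengths, and three outputs (status text, parameter JSON, and the
# impulse-response audio file).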
demo = gr.Interface(
    fn=process_fdn,
    inputs=[
        gr.Number(label="N (Number of delay lines)", value=4, precision=0),
        gr.Dataframe(
            headers=["Delay Length"], 
            type="array", 
            col_count=(1, "fixed"),
            row_count=(4, "dynamic"),
            label="Delay Lengths (N integer values)"
        ),
        gr.Number(label="Learning Rate", value=0.01, minimum=0.0001, maximum=1.0, step=0.0001),
        gr.Number(label="Sparsity Loss Weight", value=1.0, minimum=0.0, maximum=10.0, step=0.1),
        gr.Number(label="Max Epochs", value=20, precision=0, minimum=1, maximum=100)
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.File(label="Estimated Parameters (JSON)"),
        gr.Audio(label="Impulse Response", type="filepath")
    ],
    title="Feedback Delay Network Optimization",
    description="Configure your homogeneous feedback delay network by specifying N (number of delay lines) and their corresponding delay lengths. Submit the values to run optimization and obtain estimated parameters and playback the resulting impulse response."
)

if __name__ == "__main__":
    demo.launch(debug=True)