altpuppet commited on
Commit
52beecd
·
1 Parent(s): 0a3e8eb

Fix syntax error in gated_deltaproduct.py and add matplotlib dependency

Browse files
requirements.txt CHANGED
@@ -18,3 +18,4 @@ python-dateutil>=2.8.0
18
  pytz>=2021.1
19
  PyYAML>=5.4.1
20
  flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main
 
 
18
  pytz>=2021.1
19
  PyYAML>=5.4.1
20
  flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main
21
+ matplotlib>=3.5.0
src/models/gated_deltaproduct/gated_deltaproduct.py CHANGED
@@ -74,11 +74,9 @@ class GatedDeltaProduct(nn.Module):
74
  # Consistency check: Ensure expand_v produces integer values
75
  if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
76
  raise ValueError(
77
- f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "(
78
- f"Resulting value_dim would be "
79
- f"{self.num_v_heads * self.head_dim * expand_v}, "
80
- "which is invalid for nn.Linear."
81
- )
82
  )
83
  if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
84
  raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
 
74
  # Consistency check: Ensure expand_v produces integer values
75
  if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
76
  raise ValueError(
77
+ f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
78
+ f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, "
79
+ "which is invalid for nn.Linear."
 
 
80
  )
81
  if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
82
  raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")