Spaces:
Runtime error
Runtime error
Update modules/model.py
Browse files- modules/model.py +4 -6
modules/model.py
CHANGED
|
@@ -203,14 +203,14 @@ class CrossAttnProcessor(nn.Module):
|
|
| 203 |
k_bucket_size = 1024
|
| 204 |
|
| 205 |
# use flash-attention
|
| 206 |
-
hidden_states =
|
| 207 |
query.contiguous(),
|
| 208 |
key.contiguous(),
|
| 209 |
value.contiguous(),
|
| 210 |
attention_mask,
|
| 211 |
-
|
| 212 |
-
q_bucket_size
|
| 213 |
-
k_bucket_size
|
| 214 |
)
|
| 215 |
hidden_states = hidden_states.to(query.dtype)
|
| 216 |
|
|
@@ -1021,5 +1021,3 @@ class FlashAttentionFunction(Function):
|
|
| 1021 |
dvc.add_(dv_chunk)
|
| 1022 |
|
| 1023 |
return dq, dk, dv, None, None, None, None
|
| 1024 |
-
|
| 1025 |
-
FlashAttn = FlashAttentionFunction()
|
|
|
|
| 203 |
k_bucket_size = 1024
|
| 204 |
|
| 205 |
# use flash-attention
|
| 206 |
+
hidden_states = FlashAttentionFunction.apply(
|
| 207 |
query.contiguous(),
|
| 208 |
key.contiguous(),
|
| 209 |
value.contiguous(),
|
| 210 |
attention_mask,
|
| 211 |
+
False,
|
| 212 |
+
q_bucket_size,
|
| 213 |
+
k_bucket_size,
|
| 214 |
)
|
| 215 |
hidden_states = hidden_states.to(query.dtype)
|
| 216 |
|
|
|
|
| 1021 |
dvc.add_(dv_chunk)
|
| 1022 |
|
| 1023 |
return dq, dk, dv, None, None, None, None
|
|
|
|
|
|