Upload configuration_llada.py with huggingface_hub
configuration_llada.py  CHANGED  (+7, -11)
@@ -1,13 +1,11 @@
 """
 LLaDA configuration
 """
-from transformers import AutoConfig, PretrainedConfig

-from enum import Enum
-from os import PathLike
-from typing import Union
 from dataclasses import asdict, dataclass, field
+from enum import Enum
 from glob import glob
+from os import PathLike
 from pathlib import Path
 from typing import (
     Any,
@@ -22,6 +20,7 @@ from typing import (
     cast,
 )

+from transformers import AutoConfig, PretrainedConfig

 __all__ = [
     "ActivationType",
@@ -127,7 +126,7 @@ class InitFnType(StrEnum):


 @dataclass
-class ModelConfig():
+class ModelConfig:
     """
     LLaDA (model) configuration.
     """
@@ -383,6 +382,7 @@ class ModelConfig():
                     "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                 )

+
 class ActivationCheckpointingStrategy(StrEnum):
     whole_layer = "whole_layer"
     """
@@ -403,7 +403,7 @@ class ActivationCheckpointingStrategy(StrEnum):
     """
     Checkpoint one in four transformer layers.
     """
-    
+
     two_in_three = "two_in_three"
     """
     Checkpoint two out of every three transformer layers.
@@ -439,11 +439,7 @@ class LLaDAConfig(PretrainedConfig):
         all_kwargs = model_config.__dict__
         all_kwargs.update(kwargs)
         all_kwargs.update({"use_cache": use_cache})
-        all_kwargs.update(
-            {
-                "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
-            }
-        )
+        all_kwargs.update({"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])})
         super().__init__(**all_kwargs)

     @property
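For context on the final hunk: collapsing the multi-line all_kwargs.update(...) call into one line does not change behavior. The "architectures" key still falls back to ["LLaDAModelLM"] only when the caller has not already supplied one. A minimal standalone sketch of that fallback pattern, using plain dicts; the function name and the "MyCustomLM" value are illustrative, not part of the repository:

# Sketch of the `architectures` fallback used in LLaDAConfig.__init__;
# standalone illustration, not the repository's code.

def with_default_architectures(all_kwargs: dict) -> dict:
    # Keep an explicitly provided `architectures` entry; otherwise fall back
    # to ["LLaDAModelLM"]. The result is identical before and after the diff.
    all_kwargs.update({"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])})
    return all_kwargs

print(with_default_architectures({"use_cache": False}))
# {'use_cache': False, 'architectures': ['LLaDAModelLM']}

print(with_default_architectures({"architectures": ["MyCustomLM"]}))  # hypothetical value
# {'architectures': ['MyCustomLM']}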