JakeOh committed on
Commit f6b5ce6 · verified · 1 Parent(s): 2d6e01c

Upload configuration_llada.py with huggingface_hub
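For context, a single-file upload like this one is usually done with the upload_file helper from the huggingface_hub library. The sketch below is illustrative only; the repo_id is a placeholder, not the repository this commit belongs to.

from huggingface_hub import upload_file

# Minimal sketch of pushing one file to the Hub with huggingface_hub.
# NOTE: "your-username/your-llada-repo" is a hypothetical placeholder repo id.
upload_file(
    path_or_fileobj="configuration_llada.py",  # local file to upload
    path_in_repo="configuration_llada.py",     # destination path inside the repo
    repo_id="your-username/your-llada-repo",   # placeholder, replace with the target repo
    commit_message="Upload configuration_llada.py with huggingface_hub",
)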

Files changed (1)
  1. configuration_llada.py +7 -11
configuration_llada.py CHANGED
@@ -1,13 +1,11 @@
 """
 LLaDA configuration
 """
-from transformers import AutoConfig, PretrainedConfig
 
-from enum import Enum
-from os import PathLike
-from typing import Union
 from dataclasses import asdict, dataclass, field
+from enum import Enum
 from glob import glob
+from os import PathLike
 from pathlib import Path
 from typing import (
     Any,
@@ -22,6 +20,7 @@ from typing import (
     cast,
 )
 
+from transformers import AutoConfig, PretrainedConfig
 
 __all__ = [
     "ActivationType",
@@ -127,7 +126,7 @@ class InitFnType(StrEnum):
 
 
 @dataclass
-class ModelConfig():
+class ModelConfig:
     """
     LLaDA (model) configuration.
     """
@@ -383,6 +382,7 @@ class ModelConfig():
                 "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
             )
 
+
 class ActivationCheckpointingStrategy(StrEnum):
     whole_layer = "whole_layer"
     """
@@ -403,7 +403,7 @@ class ActivationCheckpointingStrategy(StrEnum):
     """
     Checkpoint one in four transformer layers.
     """
-    
+
     two_in_three = "two_in_three"
     """
     Checkpoint two out of every three transformer layers.
@@ -439,11 +439,7 @@ class LLaDAConfig(PretrainedConfig):
         all_kwargs = model_config.__dict__
         all_kwargs.update(kwargs)
         all_kwargs.update({"use_cache": use_cache})
-        all_kwargs.update(
-            {
-                "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
-            }
-        )
+        all_kwargs.update({"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])})
         super().__init__(**all_kwargs)
 
     @property
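The last hunk is purely cosmetic: the architectures default of ["LLaDAModelLM"] is still applied when the caller does not pass one, just written on a single line. As a hedged sketch of how such a remote-code config is typically consumed (the repo_id below is a placeholder, not this repository's id):

from transformers import AutoConfig

# Minimal sketch; assumes configuration_llada.py is registered for remote code
# in the target repo. The repo id is a hypothetical placeholder.
config = AutoConfig.from_pretrained(
    "your-username/your-llada-repo",
    trust_remote_code=True,  # required so the custom LLaDAConfig class is loaded
)
print(config.architectures)  # expected to default to ["LLaDAModelLM"]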