Kameshr committed on
Commit
c9f3e76
·
verified ·
1 Parent(s): 551090e

Update modeling_qwen2_custom.py

Browse files
Files changed (1) hide show
  1. modeling_qwen2_custom.py +13 -3
modeling_qwen2_custom.py CHANGED
@@ -1,15 +1,24 @@
1
  from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2Model, Qwen2ForCausalLM
2
  from transformers.configuration_utils import PretrainedConfig
 
3
  import torch
4
  import torch.nn as nn
5
 
6
  class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
7
  def __init__(self, config, layer_idx):
8
  super().__init__(config, layer_idx)
9
- # FIX: Reverted to 1D shape [5120] to match the saved checkpoint
 
 
 
 
 
 
 
 
10
  self.register_buffer(
11
  "resid_bias",
12
- torch.zeros(config.hidden_size),
13
  persistent=True
14
  )
15
 
@@ -23,7 +32,8 @@ class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
23
  bias = self.resid_bias.to(hidden.device).to(hidden.dtype)
24
 
25
  if bias.norm() > 0:
26
- # view(1, 1, -1) safely converts [5120] -> [1, 1, 5120] for broadcasting
 
27
  hidden = hidden + bias.view(1, 1, -1)
28
 
29
  if isinstance(outputs, tuple): outputs = (hidden,) + outputs[1:]
 
1
  from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2Model, Qwen2ForCausalLM
2
  from transformers.configuration_utils import PretrainedConfig
3
+ from transformers import AutoConfig
4
  import torch
5
  import torch.nn as nn
6
 
7
  class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
8
def __init__(self, config, layer_idx):
    """Decoder layer carrying a persistent additive residual-bias buffer.

    The saved checkpoint stores ``resid_bias`` with two different shapes:
    layers 28-53 were serialized as ``[1, hidden_size]`` while every other
    layer was serialized as ``[hidden_size]``.  ``load_state_dict`` requires
    the registered buffer to match the stored shape exactly, so the shape is
    selected from ``layer_idx`` before registration.

    Args:
        config: model configuration; only ``hidden_size`` is read here.
        layer_idx: zero-based index of this decoder layer in the stack.
    """
    super().__init__(config, layer_idx)

    # Mirror the per-layer shape the checkpoint was written with.
    saved_as_2d = 28 <= layer_idx <= 53
    buffer_shape = (1, config.hidden_size) if saved_as_2d else (config.hidden_size,)

    # persistent=True so the buffer round-trips through state_dict.
    self.register_buffer("resid_bias", torch.zeros(buffer_shape), persistent=True)
24
 
 
32
  bias = self.resid_bias.to(hidden.device).to(hidden.dtype)
33
 
34
  if bias.norm() > 0:
35
+ # view(1, 1, -1) flattens either saved shape — [hidden_size] or
36
+ # [1, hidden_size] — to [1, 1, hidden_size] so it broadcasts over (batch, seq).
37
  hidden = hidden + bias.view(1, 1, -1)
38
 
39
  if isinstance(outputs, tuple): outputs = (hidden,) + outputs[1:]