Skip to content

Commit

Permalink
make field non mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjeblick committed Dec 10, 2024
1 parent afc87a3 commit fa78b82
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 27 deletions.
12 changes: 5 additions & 7 deletions kvpress/presses/expected_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@

@dataclass
class ExpectedAttentionPress(ScorerPress):
scorer: ExpectedAttentionScorer = field(default_factory=ExpectedAttentionScorer)
scorer: ExpectedAttentionScorer = field(default_factory=ExpectedAttentionScorer, init=False)
compression_ratio: float = 0.0
n_future_positions: int = 512
n_sink: int = 4
use_covariance: bool = True
use_vnorm: bool = True

def __post_init__(self):
self.scorer = ExpectedAttentionScorer(
n_future_positions=self.n_future_positions,
n_sink=self.n_sink,
use_covariance=self.use_covariance,
use_vnorm=self.use_vnorm,
)
self.scorer.n_future_positions = self.n_future_positions
self.scorer.n_sink = self.n_sink
self.scorer.use_covariance = self.use_covariance
self.scorer.use_vnorm = self.use_vnorm
super().__post_init__()
6 changes: 1 addition & 5 deletions kvpress/presses/knorm_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,5 @@

@dataclass
class KnormPress(ScorerPress):
scorer: KnormScorer = field(default_factory=KnormScorer)
scorer: KnormScorer = field(default_factory=KnormScorer, init=False)
compression_ratio: float = 0.0

def __post_init__(self):
self.scorer = KnormScorer()
super().__post_init__()
2 changes: 1 addition & 1 deletion kvpress/presses/observed_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ObservedAttentionPress(ScorerPress):
It will not return attentions in its output to save memory.
"""

scorer: ObservedAttentionScorer = field(default_factory=ObservedAttentionScorer)
scorer: ObservedAttentionScorer = field(default_factory=ObservedAttentionScorer, init=False)
compression_ratio: float = 0.0
output_attentions: bool = False

Expand Down
6 changes: 1 addition & 5 deletions kvpress/presses/random_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,5 @@

@dataclass
class RandomPress(ScorerPress):
scorer: RandomScorer = field(default_factory=RandomScorer)
scorer: RandomScorer = field(default_factory=RandomScorer, init=False)
compression_ratio: float = 0.0

def __post_init__(self):
self.scorer = RandomScorer()
super().__post_init__()
8 changes: 3 additions & 5 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@

@dataclass
class SnapKVPress(ScorerPress):
scorer: SnapKVScorer = field(default_factory=SnapKVScorer)
scorer: SnapKVScorer = field(default_factory=SnapKVScorer, init=False)
compression_ratio: float = 0.0
window_size: int = 64
kernel_size: int = 5

def __post_init__(self):
self.scorer = SnapKVScorer(
window_size=self.window_size,
kernel_size=self.kernel_size,
)
self.scorer.window_size = self.window_size
self.scorer.kernel_size = self.kernel_size
super().__post_init__()
4 changes: 2 additions & 2 deletions kvpress/presses/streaming_llm_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

@dataclass
class StreamingLLMPress(ScorerPress):
scorer: StreamingLLMScorer = field(default_factory=StreamingLLMScorer)
scorer: StreamingLLMScorer = field(default_factory=StreamingLLMScorer, init=False)
compression_ratio: float = 0.0
n_sink: int = 4

def __post_init__(self):
self.scorer = StreamingLLMScorer(n_sink=self.n_sink)
self.scorer.n_sink = self.n_sink
super().__post_init__()
2 changes: 1 addition & 1 deletion kvpress/presses/tova_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@dataclasses.dataclass
class TOVAPress(ScorerPress):
scorer: TOVAScorer = dataclasses.field(default_factory=TOVAScorer)
scorer: TOVAScorer = dataclasses.field(default_factory=TOVAScorer, init=False)
compression_ratio: float = 0.0
window_size: int = 1

Expand Down
2 changes: 1 addition & 1 deletion tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
StreamingLLMPress,
TOVAPress,
)
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.scorers.base_scorer import BaseScorer
from kvpress.presses.think_press import ThinKPress
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401


Expand Down

0 comments on commit fa78b82

Please sign in to comment.