diff --git a/docs/source/api/components.rst b/docs/source/api/components.rst index 5ac2f9f..0a82082 100644 --- a/docs/source/api/components.rst +++ b/docs/source/api/components.rst @@ -8,22 +8,27 @@ Modular torch.nn.Module components for building custom architectures. Text Embedding -------------- -TextEmbedder -~~~~~~~~~~~~ +Text embedding is split into two composable stages: -Embeds text tokens with optional self-attention. +1. **TokenEmbedder** — maps each token to a dense vector (with optional self-attention). Output: ``(batch, seq_len, embedding_dim)``. +2. **SentenceEmbedder** — aggregates token vectors into a sentence embedding. Output: ``(batch, embedding_dim)`` or ``(batch, num_classes, embedding_dim)`` with label attention. -.. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedder +TokenEmbedder +~~~~~~~~~~~~~ + +Embeds tokenized text with optional self-attention. + +.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedder :members: :undoc-members: :show-inheritance: -TextEmbedderConfig -~~~~~~~~~~~~~~~~~~ +TokenEmbedderConfig +~~~~~~~~~~~~~~~~~~~ -Configuration for TextEmbedder. +Configuration for TokenEmbedder. -.. autoclass:: torchTextClassifiers.model.components.text_embedder.TextEmbedderConfig +.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig :members: :undoc-members: :show-inheritance: @@ -32,31 +37,90 @@ Example: .. code-block:: python - from torchTextClassifiers.model.components import TextEmbedder, TextEmbedderConfig + from torchTextClassifiers.model.components import ( + TokenEmbedder, TokenEmbedderConfig, AttentionConfig, + ) - # Simple text embedder - config = TextEmbedderConfig( + # Simple token embedder (no self-attention) + config = TokenEmbedderConfig( vocab_size=5000, embedding_dim=128, - attention_config=None + padding_idx=0, ) - embedder = TextEmbedder(config) + token_embedder = TokenEmbedder(config) + out = token_embedder(input_ids, attention_mask) + # out["token_embeddings"]: (batch, seq_len, 128) # With self-attention - from torchTextClassifiers.model.components import AttentionConfig - attention_config = AttentionConfig( - n_embd=128, + n_layers=2, n_head=4, - n_layer=2, - dropout=0.1 + n_kv_head=4, + positional_encoding=False, ) - config = TextEmbedderConfig( + config = TokenEmbedderConfig( vocab_size=5000, embedding_dim=128, - attention_config=attention_config + padding_idx=0, + attention_config=attention_config, + ) + token_embedder = TokenEmbedder(config) + +SentenceEmbedder +~~~~~~~~~~~~~~~~ + +Aggregates per-token embeddings into a single sentence embedding. + +.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedder + :members: + :undoc-members: + :show-inheritance: + +SentenceEmbedderConfig +~~~~~~~~~~~~~~~~~~~~~~ + +Configuration for SentenceEmbedder. + +.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedderConfig + :members: + :undoc-members: + :show-inheritance: + +LabelAttentionConfig +~~~~~~~~~~~~~~~~~~~~ + +Configuration for the label-attention aggregation mode. + +.. autoclass:: torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig + :members: + :undoc-members: + :show-inheritance: + +Example: + +.. 
code-block:: python + + from torchTextClassifiers.model.components import ( + SentenceEmbedder, SentenceEmbedderConfig, + LabelAttentionConfig, ) - embedder = TextEmbedder(config) + + # Mean-pooling (default) + sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean")) + out = sentence_embedder(token_embeddings, attention_mask) + # out["sentence_embedding"]: (batch, 128) + + # Label attention — one embedding per class + sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig( + n_head=4, + num_classes=6, + embedding_dim=128, + ), + )) + out = sentence_embedder(token_embeddings, attention_mask) + # out["sentence_embedding"]: (batch, num_classes, 128) Categorical Features -------------------- @@ -246,22 +310,31 @@ Components can be composed to create custom architectures: .. code-block:: python + import torch import torch.nn as nn from torchTextClassifiers.model.components import ( - TextEmbedder, CategoricalVariableNet, ClassificationHead + TokenEmbedder, TokenEmbedderConfig, + SentenceEmbedder, SentenceEmbedderConfig, + CategoricalVariableNet, ClassificationHead, ) class CustomModel(nn.Module): def __init__(self): super().__init__() - self.text_embedder = TextEmbedder(text_config) + self.token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=5000, embedding_dim=128, padding_idx=0, + )) + self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig()) self.cat_net = CategoricalVariableNet(...) self.head = ClassificationHead(...) - def forward(self, input_ids, categorical_data): - text_features = self.text_embedder(input_ids) + def forward(self, input_ids, attention_mask, categorical_data): + token_out = self.token_embedder(input_ids, attention_mask) + sent_out = self.sentence_embedder( + token_out["token_embeddings"], token_out["attention_mask"] + ) cat_features = self.cat_net(categorical_data) - combined = torch.cat([text_features, cat_features], dim=1) + combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1) return self.head(combined) See Also @@ -270,4 +343,3 @@ See Also * :doc:`model` - How components are used in models * :doc:`../architecture/overview` - Architecture explanation * :doc:`configs` - ModelConfig for component configuration - diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 5a7b468..14d44c5 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,7 +11,7 @@ The API is organized into several modules: * :doc:`wrapper` - High-level torchTextClassifiers wrapper class * :doc:`configs` - Configuration classes (ModelConfig, TrainingConfig) * :doc:`tokenizers` - Text tokenization (NGram, WordPiece, HuggingFace) -* :doc:`components` - Model components (TextEmbedder, CategoricalVariableNet, etc.) +* :doc:`components` - Model components (TokenEmbedder, SentenceEmbedder, CategoricalVariableNet, etc.) 
* :doc:`model` - Core PyTorch models * :doc:`dataset` - Dataset classes for data loading @@ -30,7 +30,8 @@ Most Used Classes Architecture Components ~~~~~~~~~~~~~~~~~~~~~~~ -* :class:`torchTextClassifiers.model.components.TextEmbedder` - Text embedding layer +* :class:`torchTextClassifiers.model.components.text_embedder.TokenEmbedder` - Token embedding layer +* :class:`torchTextClassifiers.model.components.text_embedder.SentenceEmbedder` - Sentence aggregation layer * :class:`torchTextClassifiers.model.components.CategoricalVariableNet` - Categorical features * :class:`torchTextClassifiers.model.components.ClassificationHead` - Classification layer * :class:`torchTextClassifiers.model.components.Attention.AttentionConfig` - Attention configuration diff --git a/docs/source/api/model.rst b/docs/source/api/model.rst index e31eac0..16690a4 100644 --- a/docs/source/api/model.rst +++ b/docs/source/api/model.rst @@ -20,11 +20,12 @@ Core PyTorch nn.Module combining all components. **Architecture:** - The model combines three main components: + The model combines four main components: - 1. **TextEmbedder**: Converts tokens to embeddings - 2. **CategoricalVariableNet** (optional): Handles categorical features - 3. **ClassificationHead**: Produces class logits + 1. **TokenEmbedder**: Maps each token to a dense vector (with optional self-attention) + 2. **SentenceEmbedder**: Aggregates token vectors into a sentence representation + 3. **CategoricalVariableNet** (optional): Handles categorical features + 4. **ClassificationHead**: Produces class logits Example: @@ -32,37 +33,40 @@ Example: from torchTextClassifiers.model import TextClassificationModel from torchTextClassifiers.model.components import ( - TextEmbedder, TextEmbedderConfig, - CategoricalVariableNet, CategoricalForwardType, - ClassificationHead + TokenEmbedder, TokenEmbedderConfig, + SentenceEmbedder, SentenceEmbedderConfig, + CategoricalVariableNet, + ClassificationHead, ) # Create components - text_embedder = TextEmbedder(TextEmbedderConfig( + token_embedder = TokenEmbedder(TokenEmbedderConfig( vocab_size=5000, - embedding_dim=128 + embedding_dim=128, + padding_idx=0, )) + sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean")) cat_net = CategoricalVariableNet( - vocabulary_sizes=[10, 20], - embedding_dims=[8, 16], - forward_type=CategoricalForwardType.AVERAGE_AND_CONCAT + categorical_vocabulary_sizes=[10, 20], + categorical_embedding_dims=[8, 16], ) classification_head = ClassificationHead( input_dim=128 + 24, # text_dim + cat_dim - num_classes=5 + num_classes=5, ) # Combine into model model = TextClassificationModel( - text_embedder=text_embedder, + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, categorical_variable_net=cat_net, - classification_head=classification_head + classification_head=classification_head, ) # Forward pass - logits = model(input_ids, categorical_data) + logits = model(input_ids, attention_mask, categorical_data) PyTorch Lightning Module ------------------------- diff --git a/docs/source/architecture/overview.md b/docs/source/architecture/overview.md index 3de20a6..737fa28 100644 --- a/docs/source/architecture/overview.md +++ b/docs/source/architecture/overview.md @@ -157,50 +157,53 @@ output = tokenizer(["Hello world!", "Text classification"]) # output.attention_mask: Attention mask (batch_size, seq_len) ``` -## Component 2: Text Embedder +## Component 2: Text Embedding Pipeline (TokenEmbedder + SentenceEmbedder) -**Purpose:** Convert tokens into dense, 
semantic embeddings that capture meaning. +**Purpose:** Convert tokens into a single dense vector per sample (or one per class with label attention). -### Basic Text Embedding +Text embedding is split into two distinct, composable stages: + +- **`TokenEmbedder`**: maps each input token to a dense vector, with optional self-attention. Output shape: `(batch, seq_len, embedding_dim)`. +- **`SentenceEmbedder`**: aggregates per-token vectors into a fixed-size sentence representation. Output shape: `(batch, embedding_dim)`, or `(batch, num_classes, embedding_dim)` when label attention is enabled. + +### Stage 1 — TokenEmbedder ```python -from torchTextClassifiers.model.components import TextEmbedder, TextEmbedderConfig +from torchTextClassifiers.model.components import TokenEmbedder, TokenEmbedderConfig -config = TextEmbedderConfig( +config = TokenEmbedderConfig( vocab_size=5000, embedding_dim=128, + padding_idx=0, ) -embedder = TextEmbedder(config) +token_embedder = TokenEmbedder(config) -# Forward pass -text_features = embedder(token_ids) # Shape: (batch_size, 128) +# Forward pass — returns a dict +out = token_embedder(input_ids, attention_mask) +# out["token_embeddings"]: (batch_size, seq_len, 128) ``` -**How it works:** -1. Looks up embedding for each token -2. Averages embeddings across the sequence -3. Produces a fixed-size vector per sample - -### With Self-Attention (Optional) +#### With Self-Attention (Optional) Add transformer-style self-attention for better contextual understanding: ```python -from torchTextClassifiers.model.components import AttentionConfig +from torchTextClassifiers.model.components import AttentionConfig, TokenEmbedder, TokenEmbedderConfig attention_config = AttentionConfig( - n_embd=128, - n_head=4, # Number of attention heads - n_layer=2, # Number of transformer blocks - dropout=0.1, + n_layers=2, + n_head=4, + n_kv_head=4, + positional_encoding=False, ) -config = TextEmbedderConfig( +config = TokenEmbedderConfig( vocab_size=5000, embedding_dim=128, - attention_config=attention_config, # Add attention + padding_idx=0, + attention_config=attention_config, ) -embedder = TextEmbedder(config) +token_embedder = TokenEmbedder(config) ``` **When to use attention:** @@ -211,12 +214,36 @@ embedder = TextEmbedder(config) **Configuration:** - `embedding_dim`: Size of embedding vectors (e.g., 64, 128, 256) - `n_head`: Number of attention heads (typically 4, 8, or 16) -- `n_layer`: Depth of transformer (start with 2-3) +- `n_layers`: Depth of transformer (start with 2-3) + +### Stage 2 — SentenceEmbedder + +`SentenceEmbedder` collapses the `(batch, seq_len, dim)` token matrix into a sentence vector using one of several aggregation strategies: + +```python +from torchTextClassifiers.model.components import SentenceEmbedder, SentenceEmbedderConfig + +# Mean-pooling (default) +sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean")) + +out = sentence_embedder(token_embeddings, attention_mask) +# out["sentence_embedding"]: (batch_size, 128) +``` + +Available aggregation methods: + +| `aggregation_method` | Description | +|----------------------|-------------| +| `"mean"` (default) | Masked mean of token embeddings | +| `"first"` | First token (e.g. 
`[CLS]` for BERT-style models) | +| `"last"` | Last non-padding token (GPT-style) | +| `None` | Use label attention (see below) | ### With Label Attention (Optional Explainability Layer) -Label attention replaces mean-pooling with a **cross-attention mechanism** where each -class has a learnable embedding that attends over the token sequence: +Setting `aggregation_method=None` and providing a `LabelAttentionConfig` replaces +mean-pooling with a **cross-attention mechanism** where each class has a learnable +embedding that attends over the token sequence: ``` Token embeddings (batch, seq_len, d) @@ -228,7 +255,7 @@ ClassificationHead (d → 1) ← shared, applied per class Logits (batch, num_classes) ``` -Enable it by setting `n_heads_label_attention` in `ModelConfig`: +Enable it by setting `n_heads_label_attention` in `ModelConfig` (high-level API): ```python model_config = ModelConfig( @@ -238,11 +265,28 @@ model_config = ModelConfig( ) ``` +Or directly with the low-level components: + +```python +from torchTextClassifiers.model.components import ( + LabelAttentionConfig, SentenceEmbedder, SentenceEmbedderConfig, +) + +sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig( + n_head=4, + num_classes=6, + embedding_dim=96, + ), +)) +``` + **Benefits:** - Free explainability at inference time (`explain_with_label_attention=True` in `predict`) - The returned attention matrix `(batch, n_head, num_classes, seq_len)` shows which tokens each class focuses on -- Can be combined with self-attention (`attention_config`) +- Can be combined with self-attention in `TokenEmbedder` **Constraint:** `embedding_dim` must be divisible by `n_heads_label_attention`. @@ -387,15 +431,25 @@ The framework automatically combines all components: ```python from torchTextClassifiers.model import TextClassificationModel +from torchTextClassifiers.model.components import ( + TokenEmbedder, TokenEmbedderConfig, + SentenceEmbedder, SentenceEmbedderConfig, +) + +token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=5000, embedding_dim=128, padding_idx=0, +)) +sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean")) model = TextClassificationModel( - text_embedder=text_embedder, + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, categorical_variable_net=cat_handler, # Optional classification_head=head, ) # Forward pass -logits = model(token_ids, categorical_data) +logits = model(input_ids, attention_mask, categorical_data) ``` ## Usage Examples @@ -490,19 +544,28 @@ For maximum flexibility, compose components manually: ```python from torch import nn -from torchTextClassifiers.model.components import TextEmbedder, ClassificationHead +from torchTextClassifiers.model.components import ( + TokenEmbedder, TokenEmbedderConfig, + SentenceEmbedder, SentenceEmbedderConfig, + ClassificationHead, +) -# Create custom model class CustomClassifier(nn.Module): def __init__(self): super().__init__() - self.text_embedder = TextEmbedder(text_config) + self.token_embedder = TokenEmbedder(TokenEmbedderConfig( + vocab_size=5000, embedding_dim=128, padding_idx=0, + )) + self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig()) self.custom_layer = nn.Linear(128, 64) self.head = ClassificationHead(64, num_classes) - def forward(self, input_ids): - text_features = self.text_embedder(input_ids) - custom_features = self.custom_layer(text_features) + def forward(self, input_ids, attention_mask): + 
token_out = self.token_embedder(input_ids, attention_mask) + sent_out = self.sentence_embedder( + token_out["token_embeddings"], token_out["attention_mask"] + ) + custom_features = self.custom_layer(sent_out["sentence_embedding"]) return self.head(custom_features) ``` @@ -544,12 +607,18 @@ All components are standard `torch.nn.Module` objects: ```python # All components work with standard PyTorch -isinstance(text_embedder, nn.Module) # True +isinstance(token_embedder, nn.Module) # True +isinstance(sentence_embedder, nn.Module) # True isinstance(cat_handler, nn.Module) # True isinstance(head, nn.Module) # True # Use in any PyTorch code -model = TextClassificationModel(text_embedder, cat_handler, head) +model = TextClassificationModel( + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, + categorical_variable_net=cat_handler, + classification_head=head, +) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # Standard PyTorch training loop @@ -597,8 +666,9 @@ Each component is independent and can be used separately: # Use just the tokenizer tokenizer = NGramTokenizer() -# Use just the embedder -embedder = TextEmbedder(config) +# Use just the token embedder or sentence embedder +token_embedder = TokenEmbedder(TokenEmbedderConfig(...)) +sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig()) # Use just the classifier head head = ClassificationHead(input_dim, num_classes) @@ -610,13 +680,20 @@ Mix and match components for your use case: ```python # Text only -model = TextClassificationModel(text_embedder, None, head) +model = TextClassificationModel( + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, + categorical_variable_net=None, + classification_head=head, +) # Text + categorical -model = TextClassificationModel(text_embedder, cat_handler, head) - -# Custom combination -model = MyCustomModel(text_embedder, my_layer, head) +model = TextClassificationModel( + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, + categorical_variable_net=cat_handler, + classification_head=head, +) ``` ### Simplicity @@ -642,18 +719,19 @@ model_config = ModelConfig( Easy to add custom components: ```python -class MyCustomEmbedder(nn.Module): +class MyCustomTokenEmbedder(nn.Module): def __init__(self): super().__init__() # Your custom implementation - def forward(self, input_ids): - # Your custom forward pass - return embeddings + def forward(self, input_ids, attention_mask): + # Your custom forward pass — must return a dict with "token_embeddings" + return {"token_embeddings": embeddings, "attention_mask": attention_mask} # Use with existing components model = TextClassificationModel( - text_embedder=MyCustomEmbedder(), + token_embedder=MyCustomTokenEmbedder(), + sentence_embedder=SentenceEmbedder(SentenceEmbedderConfig()), classification_head=head, ) ``` @@ -696,9 +774,10 @@ torchTextClassifiers provides a **component-based pipeline** for text classifica 0. **ValueEncoder** (optional) → Encodes raw string inputs; decodes predictions back to original labels 1. **Tokenizer** → Converts text to tokens -2. **Text Embedder** → Creates semantic embeddings (with optional self-attention and/or label attention) -3. **Categorical Handler** (optional) → Processes additional categorical features -4. **Classification Head** → Produces predictions +2. **TokenEmbedder** → Embeds tokens into dense vectors (with optional self-attention) → `(batch, seq_len, dim)` +3. 
**SentenceEmbedder** → Aggregates token vectors into a sentence embedding (mean / first / last / label attention) → `(batch, dim)` or `(batch, num_classes, dim)` +4. **Categorical Handler** (optional) → Processes additional categorical features +5. **Classification Head** → Produces predictions **Key Benefits:** - Clear data flow through intuitive components diff --git a/docs/source/index.md b/docs/source/index.md index bba98fc..4430708 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -166,7 +166,7 @@ Work with a consistent, simple API whether you're doing binary, multiclass, or m ### Flexible Components -All components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) are standard `torch.nn.Module` objects. Mix and match them or create your own custom components. +All components (`TokenEmbedder`, `SentenceEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) are standard `torch.nn.Module` objects. Mix and match them or create your own custom components. ### Production Ready diff --git a/tests/conftest.py b/tests/conftest.py index 4023570..ca8bc65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,7 @@ def mock_tokenizer(): tokenizer = Mock() tokenizer.vocab_size = 1000 tokenizer.padding_idx = 1 + tokenizer.output_vectorized = False tokenizer.tokenize = Mock( return_value={ "input_ids": np.array([[1, 2, 3], [4, 5, 6]]), diff --git a/tests/test_components.py b/tests/test_components.py new file mode 100644 index 0000000..796135c --- /dev/null +++ b/tests/test_components.py @@ -0,0 +1,243 @@ +import pytest +import torch + +from torchTextClassifiers.model.components import ( + AttentionConfig, + CategoricalForwardType, + CategoricalVariableNet, + ClassificationHead, + LabelAttentionConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, +) +from torchTextClassifiers.model.model import TextClassificationModel + +BATCH = 4 +SEQ_LEN = 20 +EMB_DIM = 16 # divisible by 4 (n_head) and head_dim=4 is even (rotary) +VOCAB_SIZE = 100 +PADDING_IDX = 0 +NUM_CLASSES = 3 + + +@pytest.fixture +def input_ids(): + ids = torch.randint(1, VOCAB_SIZE, (BATCH, SEQ_LEN)) + ids[:, -2:] = PADDING_IDX + return ids + + +@pytest.fixture +def attention_mask(input_ids): + return (input_ids != PADDING_IDX).long() + + +@pytest.fixture +def token_embeddings(): + return torch.randn(BATCH, SEQ_LEN, EMB_DIM) + + +class TestTokenEmbedder: + def test_no_attention(self, input_ids, attention_mask): + embedder = TokenEmbedder( + TokenEmbedderConfig( + vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX + ) + ) + out = embedder(input_ids, attention_mask) + assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM) + assert out["attention_mask"].shape == (BATCH, SEQ_LEN) + + def test_with_attention(self, input_ids, attention_mask): + embedder = TokenEmbedder( + TokenEmbedderConfig( + vocab_size=VOCAB_SIZE, + embedding_dim=EMB_DIM, + padding_idx=PADDING_IDX, + attention_config=AttentionConfig( + n_layers=2, n_head=4, n_kv_head=4, positional_encoding=False + ), + ) + ) + out = embedder(input_ids, attention_mask) + assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM) + + def test_with_rotary_positional_encoding(self, input_ids, attention_mask): + embedder = TokenEmbedder( + TokenEmbedderConfig( + vocab_size=VOCAB_SIZE, + embedding_dim=EMB_DIM, + padding_idx=PADDING_IDX, + attention_config=AttentionConfig( + n_layers=1, + n_head=4, + n_kv_head=4, + positional_encoding=True, + sequence_len=SEQ_LEN, + ), + ) + ) + 
out = embedder(input_ids, attention_mask) + assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM) + + def test_shape_mismatch_raises(self): + embedder = TokenEmbedder( + TokenEmbedderConfig( + vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX + ) + ) + with pytest.raises(ValueError): + embedder( + torch.randint(1, VOCAB_SIZE, (BATCH, SEQ_LEN)), + torch.ones(BATCH, SEQ_LEN + 1, dtype=torch.long), + ) + + +class TestSentenceEmbedder: + @pytest.mark.parametrize("method", ["mean", "first", "last"]) + def test_aggregation_methods(self, token_embeddings, attention_mask, method): + embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method=method)) + out = embedder(token_embeddings, attention_mask) + assert out["sentence_embedding"].shape == (BATCH, EMB_DIM) + assert out["label_attention_matrix"] is None + + def test_label_attention_output_shape(self, token_embeddings, attention_mask): + embedder = SentenceEmbedder( + SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig( + n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM + ), + ) + ) + out = embedder(token_embeddings, attention_mask) + assert out["sentence_embedding"].shape == (BATCH, NUM_CLASSES, EMB_DIM) + assert out["label_attention_matrix"] is None + + def test_label_attention_matrix_returned(self, token_embeddings, attention_mask): + embedder = SentenceEmbedder( + SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig( + n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM + ), + ) + ) + out = embedder(token_embeddings, attention_mask, return_label_attention_matrix=True) + assert out["label_attention_matrix"].shape == (BATCH, 4, NUM_CLASSES, SEQ_LEN) + + def test_none_aggregation_without_label_attention_raises(self): + with pytest.raises(ValueError): + SentenceEmbedder(SentenceEmbedderConfig(aggregation_method=None)) + + +class TestCategoricalVariableNet: + def test_concatenate_all(self): + net = CategoricalVariableNet( + categorical_vocabulary_sizes=[4, 5], + categorical_embedding_dims=[3, 6], + ) + assert net.forward_type == CategoricalForwardType.CONCATENATE_ALL + assert net.output_dim == 9 + out = net(torch.randint(0, 3, (BATCH, 2))) + assert out.shape == (BATCH, 9) + + def test_average_and_concat(self): + net = CategoricalVariableNet( + categorical_vocabulary_sizes=[4, 5], + categorical_embedding_dims=8, + ) + assert net.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT + assert net.output_dim == 8 + out = net(torch.randint(0, 3, (BATCH, 2))) + assert out.shape == (BATCH, 8) + + def test_sum_to_text(self): + net = CategoricalVariableNet( + categorical_vocabulary_sizes=[4, 5], + categorical_embedding_dims=None, + text_embedding_dim=EMB_DIM, + ) + assert net.forward_type == CategoricalForwardType.SUM_TO_TEXT + assert net.output_dim == EMB_DIM + out = net(torch.randint(0, 3, (BATCH, 2))) + assert out.shape == (BATCH, EMB_DIM) + + def test_out_of_range_value_raises(self): + net = CategoricalVariableNet( + categorical_vocabulary_sizes=[4, 5], + categorical_embedding_dims=[3, 6], + ) + with pytest.raises(ValueError): + net(torch.tensor([[10, 1]] * BATCH)) # first feature value 10 >= vocab 4 + + +class TestTextClassificationModel: + def _token_embedder(self): + return TokenEmbedder( + TokenEmbedderConfig( + vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX + ) + ) + + def _sentence_embedder(self, label_attention=False): + if label_attention: + return SentenceEmbedder( + 
SentenceEmbedderConfig( + aggregation_method=None, + label_attention_config=LabelAttentionConfig( + n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM + ), + ) + ) + return SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean")) + + def test_text_only(self, input_ids, attention_mask): + model = TextClassificationModel( + token_embedder=self._token_embedder(), + sentence_embedder=self._sentence_embedder(), + classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=NUM_CLASSES), + ) + logits = model(input_ids, attention_mask, torch.empty(BATCH, 0)) + assert logits.shape == (BATCH, NUM_CLASSES) + + def test_text_and_categorical(self, input_ids, attention_mask): + cat_net = CategoricalVariableNet( + categorical_vocabulary_sizes=[4, 5], + categorical_embedding_dims=[3, 6], + ) + model = TextClassificationModel( + token_embedder=self._token_embedder(), + sentence_embedder=self._sentence_embedder(), + categorical_variable_net=cat_net, + classification_head=ClassificationHead( + input_dim=EMB_DIM + cat_net.output_dim, num_classes=NUM_CLASSES + ), + ) + logits = model(input_ids, attention_mask, torch.randint(0, 3, (BATCH, 2))) + assert logits.shape == (BATCH, NUM_CLASSES) + + def test_label_attention_logits_and_matrix(self, input_ids, attention_mask): + model = TextClassificationModel( + token_embedder=self._token_embedder(), + sentence_embedder=self._sentence_embedder(label_attention=True), + classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=1), + ) + result = model( + input_ids, + attention_mask, + torch.empty(BATCH, 0), + return_label_attention_matrix=True, + ) + assert result["logits"].shape == (BATCH, NUM_CLASSES) + assert result["label_attention_matrix"].shape == (BATCH, 4, NUM_CLASSES, SEQ_LEN) + + def test_missing_sentence_embedder_raises(self): + with pytest.raises(ValueError): + TextClassificationModel( + token_embedder=self._token_embedder(), + sentence_embedder=None, + classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=NUM_CLASSES), + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d998989..fd1cc7c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -11,8 +11,10 @@ CategoricalVariableNet, ClassificationHead, LabelAttentionConfig, - TextEmbedder, - TextEmbedderConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, ) from torchTextClassifiers.tokenizers import NGramTokenizer from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder @@ -122,24 +124,29 @@ def run_full_pipeline( sequence_len=sequence_len, ) - # Create text embedder - text_embedder_config = TextEmbedderConfig( + # Create token embedder + token_embedder_config = TokenEmbedderConfig( vocab_size=vocab_size, embedding_dim=model_params["embedding_dim"], padding_idx=padding_idx, attention_config=attention_config, + ) + token_embedder = TokenEmbedder(token_embedder_config=token_embedder_config) + + # Create sentence embedder + sentence_embedder_config = SentenceEmbedderConfig( label_attention_config=( LabelAttentionConfig( n_head=attention_config.n_head, num_classes=num_classes, + embedding_dim=model_params["embedding_dim"], ) if label_attention_enabled else None ), + aggregation_method=None if label_attention_enabled else "mean", ) - - text_embedder = TextEmbedder(text_embedder_config=text_embedder_config) - text_embedder.init_weights() + sentence_embedder = SentenceEmbedder(sentence_embedder_config=sentence_embedder_config) # Create categorical variable net (vocab sizes 
from fitted encoder) categorical_var_net = CategoricalVariableNet( @@ -156,7 +163,8 @@ def run_full_pipeline( # Create model model = TextClassificationModel( - text_embedder=text_embedder, + token_embedder=token_embedder, + sentence_embedder=sentence_embedder, categorical_variable_net=categorical_var_net, classification_head=classification_head, ) diff --git a/torchTextClassifiers/model/components/__init__.py b/torchTextClassifiers/model/components/__init__.py index 5cad342..3db1a73 100644 --- a/torchTextClassifiers/model/components/__init__.py +++ b/torchTextClassifiers/model/components/__init__.py @@ -9,5 +9,5 @@ ) from .classification_head import ClassificationHead as ClassificationHead from .text_embedder import LabelAttentionConfig as LabelAttentionConfig -from .text_embedder import TextEmbedder as TextEmbedder -from .text_embedder import TextEmbedderConfig as TextEmbedderConfig +from .text_embedder import TokenEmbedder as TokenEmbedder, TokenEmbedderConfig as TokenEmbedderConfig +from .text_embedder import SentenceEmbedder as SentenceEmbedder, SentenceEmbedderConfig as SentenceEmbedderConfig \ No newline at end of file diff --git a/torchTextClassifiers/model/components/attention.py b/torchTextClassifiers/model/components/attention.py index 7c6474c..130e88f 100644 --- a/torchTextClassifiers/model/components/attention.py +++ b/torchTextClassifiers/model/components/attention.py @@ -35,7 +35,6 @@ class AttentionConfig: n_kv_head: int sequence_len: Optional[int] = None positional_encoding: bool = True - aggregation_method: str = "mean" # or 'last', or 'first' #### Attention Block ##### diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 9d5aaa5..be5e7fe 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -1,3 +1,4 @@ +import logging import math from dataclasses import dataclass from typing import Dict, Optional @@ -8,47 +9,58 @@ from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler()], +) + @dataclass class LabelAttentionConfig: n_head: int num_classes: int + embedding_dim: int @dataclass -class TextEmbedderConfig: +class TokenEmbedderConfig: vocab_size: int embedding_dim: int padding_idx: int attention_config: Optional[AttentionConfig] = None + + +@dataclass +class SentenceEmbedderConfig: + aggregation_method: Optional[str] = "mean" # or 'last', or 'first' label_attention_config: Optional[LabelAttentionConfig] = None -class TextEmbedder(nn.Module): - def __init__(self, text_embedder_config: TextEmbedderConfig): +class TokenEmbedder(nn.Module): + """ + A module that takes tokenized text and outputs dense vector representations (one for each token). 
+ + """ + + cos: torch.Tensor + sin: torch.Tensor + + def __init__(self, token_embedder_config: TokenEmbedderConfig): super().__init__() - self.config = text_embedder_config + self.config = token_embedder_config - self.attention_config = text_embedder_config.attention_config + self.attention_config = token_embedder_config.attention_config if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) - # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig - self.label_attention_config = text_embedder_config.label_attention_config - if isinstance(self.label_attention_config, dict): - self.label_attention_config = LabelAttentionConfig(**self.label_attention_config) - # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier) - # always see a LabelAttentionConfig instance rather than a raw dict. - self.config.label_attention_config = self.label_attention_config - - self.enable_label_attention = self.label_attention_config is not None - if self.enable_label_attention: - self.label_attention_module = LabelAttentionClassifier(self.config) - - self.vocab_size = text_embedder_config.vocab_size - self.embedding_dim = text_embedder_config.embedding_dim - self.padding_idx = text_embedder_config.padding_idx + self.vocab_size = token_embedder_config.vocab_size + self.embedding_dim = token_embedder_config.embedding_dim + self.padding_idx = token_embedder_config.padding_idx self.embedding_layer = nn.Embedding( embedding_dim=self.embedding_dim, @@ -57,7 +69,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig): ) if self.attention_config is not None: - self.attention_config.n_embd = text_embedder_config.embedding_dim + self.attention_config.n_embd: int = token_embedder_config.embedding_dim self.transformer = nn.ModuleDict( { "h": nn.ModuleList( @@ -103,10 +115,7 @@ def init_weights(self): for block in self.transformer.h: torch.nn.init.zeros_(block.mlp.c_proj.weight) torch.nn.init.zeros_(block.attn.c_proj.weight) - # init the rotary embeddings - head_dim = self.attention_config.n_embd // self.attention_config.n_head - cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) - self.cos, self.sin = cos, sin + # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations if self.embedding_layer.weight.device.type == "cuda": self.embedding_layer.to(dtype=torch.bfloat16) @@ -127,30 +136,7 @@ def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, - return_label_attention_matrix: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: - """Converts input token IDs to their corresponding embeddings. - - Args: - input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized - attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens - return_label_attention_matrix (bool): Whether to return the label attention matrix. - - Returns: - dict: A dictionary with the following keys: - - - "sentence_embedding" (torch.Tensor): Text embeddings of shape - (batch_size, embedding_dim) if ``self.enable_label_attention`` is False, - else (batch_size, num_classes, embedding_dim), where ``num_classes`` - is the number of label classes. 
- - - "label_attention_matrix" (Optional[torch.Tensor]): Label attention - matrix of shape (batch_size, n_head, num_classes, seq_len) if - ``return_label_attention_matrix`` is True and label attention is - enabled, otherwise ``None``. The dimensions correspond to - (batch_size, attention heads, label classes, sequence length). - """ - encoded_text = input_ids # clearer name if encoded_text.dtype != torch.long: encoded_text = encoded_text.to(torch.long) @@ -181,92 +167,9 @@ def forward( token_embeddings = norm(token_embeddings) - out = self._get_sentence_embedding( - token_embeddings=token_embeddings, - attention_mask=attention_mask, - return_label_attention_matrix=return_label_attention_matrix, - ) - - text_embedding = out["sentence_embedding"] - label_attention_matrix = out["label_attention_matrix"] return { - "sentence_embedding": text_embedding, - "label_attention_matrix": label_attention_matrix, - } - - def _get_sentence_embedding( - self, - token_embeddings: torch.Tensor, - attention_mask: torch.Tensor, - return_label_attention_matrix: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - """ - Compute sentence embedding from embedded tokens - "remove" second dimension. - - Args (output from dataset collate_fn): - token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text - attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens - return_label_attention_matrix (bool): Whether to compute and return the label attention matrix - Returns: - Dict[str, Optional[torch.Tensor]]: A dictionary containing: - - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled - - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None - """ - - # average over non-pad token embeddings - # attention mask has 1 for non-pad tokens and 0 for pad token positions - - # mask pad-tokens - - if self.attention_config is not None: - if self.attention_config.aggregation_method is not None: # default is "mean" - if self.attention_config.aggregation_method == "first": - return { - "sentence_embedding": token_embeddings[:, 0, :], - "label_attention_matrix": None, - } - elif self.attention_config.aggregation_method == "last": - lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1 - return { - "sentence_embedding": token_embeddings[ - torch.arange(token_embeddings.size(0)), - lengths - 1, - :, - ], - "label_attention_matrix": None, - } - else: - if self.attention_config.aggregation_method != "mean": - raise ValueError( - f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'." 
- ) - - assert self.attention_config is None or self.attention_config.aggregation_method == "mean" - - if self.enable_label_attention: - label_attention_result = self.label_attention_module( - token_embeddings, - attention_mask=attention_mask, - compute_attention_matrix=return_label_attention_matrix, - ) - sentence_embedding = label_attention_result[ - "sentence_embedding" - ] # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix - label_attention_matrix = label_attention_result["attention_matrix"] - - else: # sentence embedding = mean of (non-pad) token embeddings - mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1) - masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim) - sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp( - min=1.0 - ) # avoid division by zero - - sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0) - label_attention_matrix = None - - return { - "sentence_embedding": sentence_embedding, - "label_attention_matrix": label_attention_matrix, + "token_embeddings": token_embeddings, + "attention_mask": attention_mask, } def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): @@ -291,20 +194,23 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No return cos, sin -class LabelAttentionClassifier(nn.Module): +class LabelAttention(nn.Module): """ A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism. Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings. """ - def __init__(self, config: TextEmbedderConfig): + def __init__(self, label_attention_config: LabelAttentionConfig): super().__init__() - label_attention_config = config.label_attention_config - self.embedding_dim = config.embedding_dim + if label_attention_config is None: + raise ValueError("label_attention_config must be provided to use LabelAttention.") + + self.label_attention_config = label_attention_config self.num_classes = label_attention_config.num_classes self.n_head = label_attention_config.n_head + self.embedding_dim = label_attention_config.embedding_dim # Validate head configuration self.head_dim = self.embedding_dim // self.n_head @@ -399,3 +305,111 @@ def forward( attention_matrix = torch.softmax(attention_scores, dim=-1) return {"sentence_embedding": y, "attention_matrix": attention_matrix} + + +class SentenceEmbedder(nn.Module): + def __init__(self, sentence_embedder_config: SentenceEmbedderConfig): + super().__init__() + """ + A module to aggregate token embeddings. 
+
+        Four modes are possible:
+        - aggregation_method="mean" (default): token embeddings are averaged
+        - aggregation_method="first": sentence embedding is the first token's embedding (common in BERT-like models: the [CLS] token)
+        - aggregation_method="last": sentence embedding is the last token's embedding (common in GPT-like models)
+        - aggregation_method=None: in that case a label_attention_config must be provided
+        """
+
+        self.config = sentence_embedder_config
+        self.label_attention_config = sentence_embedder_config.label_attention_config
+        self.aggregation_method = sentence_embedder_config.aggregation_method
+
+        if isinstance(self.label_attention_config, dict):
+            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
+            # Keep self.config in sync so downstream components (e.g., LabelAttention)
+            # always see a LabelAttentionConfig instance rather than a raw dict.
+            self.config.label_attention_config = self.label_attention_config
+
+        if self.label_attention_config is not None:
+            self.label_attention_module = LabelAttention(
+                label_attention_config=self.label_attention_config
+            )
+            if self.aggregation_method is not None:
+                logger.warning(
+                    "aggregation_method is ignored when label_attention_config is provided, since label attention produces label-specific sentence embeddings without further aggregation."
+                )
+                self.aggregation_method = None  # override to avoid confusion
+
+        if self.aggregation_method not in (None, "mean", "first", "last"):
+            raise ValueError(
+                f"Unsupported aggregation method: {self.aggregation_method}. Supported methods are None, 'mean', 'first', 'last'."
+            )
+        if self.aggregation_method is None:
+            if self.label_attention_config is None:
+                raise ValueError(
+                    "aggregation_method cannot be None when label_attention_config is not provided: either specify an aggregation method (e.g. 'mean') or pass a label_attention_config to aggregate via label attention."
+                )
+
+    def forward(
+        self,
+        token_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Compute sentence embedding from embedded tokens - "remove" second dimension.
+
+        Args:
+            token_embeddings (torch.Tensor), shape (batch_size, seq_len, embedding_dim): Per-token embeddings (e.g. the output of TokenEmbedder)
+            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
+        Returns:
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
+        """
+        if self.aggregation_method is not None:  # default is "mean"
+            if self.aggregation_method == "first":
+                return {
+                    "sentence_embedding": token_embeddings[:, 0, :],
+                    "label_attention_matrix": None,
+                }
+            elif self.aggregation_method == "last":
+                lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                return {
+                    "sentence_embedding": token_embeddings[
+                        torch.arange(token_embeddings.size(0)),
+                        lengths - 1,
+                        :,
+                    ],
+                    "label_attention_matrix": None,
+                }
+            else:  # mean
+                mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+                masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+                sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+                    min=1.0
+                )  # avoid division by zero
+
+                sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+                return {
+                    "sentence_embedding": sentence_embedding,
+                    "label_attention_matrix": None,
+                }
+
+        else:
+            label_attention_result = self.label_attention_module(
+                token_embeddings,
+                attention_mask=attention_mask,
+                compute_attention_matrix=return_label_attention_matrix,
+            )
+            sentence_embedding = label_attention_result[
+                "sentence_embedding"
+            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
+            label_attention_matrix = label_attention_result["attention_matrix"]
+            return {
+                "sentence_embedding": sentence_embedding,
+                "label_attention_matrix": label_attention_matrix,
+            }
diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py
index 3630e62..0f758ca 100644
--- a/torchTextClassifiers/model/model.py
+++ b/torchTextClassifiers/model/model.py
@@ -15,7 +15,8 @@
     CategoricalForwardType,
     CategoricalVariableNet,
     ClassificationHead,
-    TextEmbedder,
+    SentenceEmbedder,
+    TokenEmbedder,
 )
 from torchTextClassifiers.model.components.attention import norm
@@ -38,12 +39,11 @@ class TextClassificationModel(nn.Module):
-    """FastText Pytorch Model."""
-
     def __init__(
         self,
         classification_head: ClassificationHead,
-        text_embedder: Optional[TextEmbedder] = None,
+        token_embedder: Optional[TokenEmbedder] = None,
+        sentence_embedder: Optional[SentenceEmbedder] = None,
         categorical_variable_net: Optional[CategoricalVariableNet] = None,
     ):
         """
@@ -51,14 +51,23 @@ def __init__(
         Args:
             classification_head (ClassificationHead): The classification head module.
-            text_embedder (Optional[TextEmbedder]): The text embedding module.
+            token_embedder (Optional[TokenEmbedder]): The token embedding module.
                 If not provided, assumes that input text is already embedded (as tensors) and directly passed to the classification head.
+            sentence_embedder (Optional[SentenceEmbedder]): Aggregates token embeddings into a sentence embedding.
+                Required whenever a token_embedder is provided.
             categorical_variable_net (Optional[CategoricalVariableNet]): The categorical variable network module.
                 If not provided, assumes no categorical variables are used.
         """
         super().__init__()
-        self.text_embedder = text_embedder
+        self.token_embedder = token_embedder
+        self.sentence_embedder = sentence_embedder
+
+        if self.token_embedder is not None:
+            if self.sentence_embedder is None:
+                raise ValueError(
+                    "You have provided a TokenEmbedder but no SentenceEmbedder: please provide one."
+                )
+            self.token_embedder.init_weights()

         self.categorical_variable_net = categorical_variable_net
         if not self.categorical_variable_net:
@@ -69,42 +78,40 @@ def __init__(
         self._validate_component_connections()

         torch.nn.init.zeros_(self.classification_head.net.weight)
-        if self.text_embedder is not None:
-            self.text_embedder.init_weights()

     def _validate_component_connections(self):
-        def _check_text_categorical_connection(self, text_embedder, cat_var_net):
+        def _check_text_categorical_connection(self, token_embedder, cat_var_net):
             if cat_var_net.forward_type == CategoricalForwardType.SUM_TO_TEXT:
-                if text_embedder.embedding_dim != cat_var_net.output_dim:
+                if token_embedder.embedding_dim != cat_var_net.output_dim:
                     raise ValueError(
                         "Text embedding dimension must match categorical variable embedding dimension."
                     )
-                self.expected_classification_head_input_dim = text_embedder.embedding_dim
+                self.expected_classification_head_input_dim = token_embedder.embedding_dim
             else:
                 self.expected_classification_head_input_dim = (
-                    text_embedder.embedding_dim + cat_var_net.output_dim
+                    token_embedder.embedding_dim + cat_var_net.output_dim
                 )

-        if self.text_embedder:
+        if self.token_embedder:
             if self.categorical_variable_net:
                 _check_text_categorical_connection(
-                    self, self.text_embedder, self.categorical_variable_net
+                    self, self.token_embedder, self.categorical_variable_net
                 )
             else:
-                self.expected_classification_head_input_dim = self.text_embedder.embedding_dim
+                self.expected_classification_head_input_dim = self.token_embedder.embedding_dim

             if self.expected_classification_head_input_dim != self.classification_head.input_dim:
                 raise ValueError(
                     "Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
                 )

-            if self.text_embedder.enable_label_attention:
+            if self.sentence_embedder.label_attention_config is not None:
                 self.enable_label_attention = True
                 if self.classification_head.num_classes != 1:
                     raise ValueError(
                         "Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1.
) # if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily - self.num_classes = self.text_embedder.config.label_attention_config.num_classes + self.num_classes = self.sentence_embedder.label_attention_config.num_classes else: self.enable_label_attention = False self.num_classes = self.classification_head.num_classes @@ -140,23 +147,24 @@ def forward( """ encoded_text = input_ids # clearer name label_attention_matrix = None - if self.text_embedder is None: + if self.token_embedder is None: x_text = encoded_text.float() if return_label_attention_matrix: raise ValueError( - "return_label_attention_matrix=True requires a text_embedder with label attention enabled" + "return_label_attention_matrix=True requires a token_embedder with label attention enabled" ) else: - text_embed_output = self.text_embedder( + token_embed_output = self.token_embedder( input_ids=encoded_text, attention_mask=attention_mask, - return_label_attention_matrix=return_label_attention_matrix, ) - x_text = text_embed_output["sentence_embedding"] - if isinstance(return_label_attention_matrix, torch.Tensor): - return_label_attention_matrix = return_label_attention_matrix[0].item() + x_token = token_embed_output["token_embeddings"] + sentence_embedding_output = self.sentence_embedder( + x_token, attention_mask, return_label_attention_matrix=return_label_attention_matrix + ) + x_text = sentence_embedding_output["sentence_embedding"] if return_label_attention_matrix: - label_attention_matrix = text_embed_output["label_attention_matrix"] + label_attention_matrix = sentence_embedding_output["label_attention_matrix"] if self.categorical_variable_net: x_cat = self.categorical_variable_net(categorical_vars) diff --git a/torchTextClassifiers/test copy.py b/torchTextClassifiers/test copy.py deleted file mode 100644 index 09dc5e6..0000000 --- a/torchTextClassifiers/test copy.py +++ /dev/null @@ -1,107 +0,0 @@ -import numpy as np -import torch - -from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers -from torchTextClassifiers.tokenizers import HuggingFaceTokenizer - -# Note: %load_ext autoreload and %autoreload 2 are specific to IPython/Notebooks -# and are omitted here for a standard Python script. - -# ========================================== -# 1. Ragged-lists approach -# ========================================== - -# In multilabel classification, each instance can be assigned multiple labels simultaneously. -# Let's use fake data where labels is a list of lists (ragged array). -sample_text_data = [ - "This is a positive example", - "This is a negative example", - "Another positive case", - "Another negative case", - "Good example here", - "Bad example here", -] - -# Each inner list contains labels for the corresponding instance -labels_ragged = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]] - -# Note: labels_ragged is a "jagged array." -# np.array(labels_ragged) would not work directly as a standard numeric matrix. -# However, torchTextClassifiers handles this directly. - -# Load a pre-trained tokenizer -tokenizer = HuggingFaceTokenizer.load_from_pretrained( - "google-bert/bert-base-uncased", output_dim=126 -) - -X = np.array(sample_text_data) -Y_ragged = labels_ragged - -# Configure the model and training -# We use BCEWithLogitsLoss for multilabel tasks to treat each label -# as a separate binary classification problem. 
-embedding_dim = 96 -num_classes = max(max(label_list) for label_list in labels_ragged) + 1 - -model_config = ModelConfig( - embedding_dim=embedding_dim, - num_classes=num_classes, -) - -training_config = TrainingConfig( - lr=1e-3, - batch_size=4, - num_epochs=1, - loss=torch.nn.BCEWithLogitsLoss(), # Essential for multilabel - raw_labels=False, - raw_categorical_inputs=False, -) - -# Initialize the classifier with ragged_multilabel=True -ttc_ragged = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config, - ragged_multilabel=True, # Key for ragged list input! -) - -print("Starting training with ragged labels...") -ttc_ragged.train( - X_train=X, - y_train=Y_ragged, - training_config=training_config, -) - -# Behind the scenes, the ragged lists are converted into a binary matrix (one-hot version). - -# ========================================== -# 2. One-hot / multidimensional output approach -# ========================================== - -# You can also provide a one-hot/multidimensional array (or float probabilities). -# Here, each row is a vector of size equal to the number of labels. -labels_one_hot = [ - [1.0, 1.0, 0.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - [1.0, 1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], -] -Y_one_hot = np.array(labels_one_hot) - -# When using one-hot/dense arrays, set ragged_multilabel=False (default) -ttc_dense = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config, -) - -print("\nStarting training with one-hot labels...") -ttc_dense.train( - X_train=X, - y_train=Y_one_hot, - training_config=training_config, -) - -# Final Note: -# - Use BCEWithLogitsLoss for multilabel settings. -# - Use CrossEntropyLoss for "soft" multiclass (where probabilities sum to 1). 
diff --git a/torchTextClassifiers/test.py b/torchTextClassifiers/test.py deleted file mode 100644 index 2758b28..0000000 --- a/torchTextClassifiers/test.py +++ /dev/null @@ -1,95 +0,0 @@ -import numpy as np -import pandas as pd -from sklearn.preprocessing import LabelEncoder - -from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers -from torchTextClassifiers.tokenizers import WordPieceTokenizer -from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder - -sample_text_data = [ - "This is a positive example", - "This is a negative example", - "Another positive case", - "Another negative case", - "Good example here", - "Bad example here", -] - -categorical_data = np.array( - [ - ["cat", "red"], - ["dog", "blue"], - ["cat", "red"], - ["dog", "blue"], - ["cat", "red"], - ["dog", "blue"], - ] -) - -labels = np.array(["positive", "negative", "positive", "negative", "positive", "negative"]) - - -df = pd.DataFrame( - { - "text": sample_text_data, - "category": categorical_data[:, 0], - "color": categorical_data[:, 1], - "label": labels, - } -) -vocab_size = 10 -tokenizer = WordPieceTokenizer(vocab_size, output_dim=50) -tokenizer.train(sample_text_data) - -encoders = {} -# category : DictEncoder (ours) -feature = "category" -mapping = {val: idx for idx, val in enumerate(df[feature].unique())} -encoders[feature] = DictEncoder(mapping) - -# color: LabelEncoder (sklearn) -le = LabelEncoder() -le.fit(df["color"]) -encoders["color"] = le - -feature = "label" -le_label = LabelEncoder() -le_label.fit(df[feature]) -label_encoder = le_label - -# OR you can also use DictEncoder -# dict_mapping = {val: idx for idx, val in enumerate(df[feature].unique())} -# label_encoder = DictEncoder(dict_mapping) - -value_encoder = ValueEncoder(label_encoder, encoders) - - -model_config = ModelConfig( - embedding_dim=10, - categorical_embedding_dims=[5, 5], -) -training_config = TrainingConfig( - num_epochs=1, - batch_size=2, - lr=1e-3, -) - -ttc = torchTextClassifiers( - tokenizer=tokenizer, - model_config=model_config, - value_encoder=value_encoder, -) - -ttc.train( - X_train=df[["text", "category", "color"]].values, - y_train=df["label"].values, - training_config=training_config, -) - -torchTextClassifiers.load("my_ttc/") - -ttc.predict( - X_test=df[["text", "category", "color"]].values, - raw_categorical_inputs=True, # Set to True since we're providing raw categorical values - top_k=2, -) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index fea5194..b5aeb10 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -30,8 +30,10 @@ CategoricalVariableNet, ClassificationHead, LabelAttentionConfig, - TextEmbedder, - TextEmbedderConfig, + SentenceEmbedder, + SentenceEmbedderConfig, + TokenEmbedder, + TokenEmbedderConfig, ) from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput from torchTextClassifiers.value_encoder import ValueEncoder @@ -56,6 +58,7 @@ class ModelConfig: categorical_embedding_dims: Optional[Union[List[int], int]] = None attention_config: Optional[AttentionConfig] = None n_heads_label_attention: Optional[int] = None + aggregation_method: Optional[str] = "mean" def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -177,26 +180,33 @@ def __init__( self.enable_label_attention = model_config.n_heads_label_attention is not None if self.tokenizer.output_vectorized: - self.text_embedder = None + self.token_embedder = None logger.info( 
"Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization." ) self.embedding_dim = self.tokenizer.output_dim else: - text_embedder_config = TextEmbedderConfig( + token_embedder_config = TokenEmbedderConfig( vocab_size=self.vocab_size, embedding_dim=self.embedding_dim, padding_idx=tokenizer.padding_idx, attention_config=model_config.attention_config, + ) + sentence_embedder_config = SentenceEmbedderConfig( label_attention_config=LabelAttentionConfig( n_head=model_config.n_heads_label_attention, num_classes=model_config.num_classes, + embedding_dim=self.embedding_dim, ) if self.enable_label_attention else None, + aggregation_method=model_config.aggregation_method, + ) + self.token_embedder = TokenEmbedder( + token_embedder_config=token_embedder_config, ) - self.text_embedder = TextEmbedder( - text_embedder_config=text_embedder_config, + self.sentence_embedder = SentenceEmbedder( + sentence_embedder_config=sentence_embedder_config ) classif_head_input_dim = self.embedding_dim @@ -221,7 +231,8 @@ def __init__( ) self.pytorch_model = TextClassificationModel( - text_embedder=self.text_embedder, + token_embedder=self.token_embedder, + sentence_embedder=self.sentence_embedder, categorical_variable_net=self.categorical_var_net, classification_head=self.classification_head, ) @@ -279,7 +290,7 @@ def train( X_train, y_train = self._check_XY( X_train, y_train, training_config.raw_categorical_inputs, training_config.raw_labels ) - print(X_train, y_train) + if X_val is not None: assert y_val is not None, "y_val must be provided if X_val is provided." if y_val is not None: @@ -572,7 +583,7 @@ def predict( if explain: return_offsets_mapping = True # to be passed to the tokenizer return_word_ids = True - if self.pytorch_model.text_embedder is None: + if self.pytorch_model.token_embedder is None: raise RuntimeError( "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs." ) @@ -583,7 +594,7 @@ def predict( "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'." ) lig = LayerIntegratedGradients( - self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer + self.pytorch_model, self.pytorch_model.token_embedder.embedding_layer ) # initialize a Captum layer gradient integrator if explain_with_label_attention: if not self.enable_label_attention: