126 changes: 99 additions & 27 deletions docs/source/api/components.rst

Modular torch.nn.Module components for building custom architectures.
Text Embedding
--------------

Text embedding is split into two composable stages:

1. **TokenEmbedder** — maps each token to a dense vector (with optional self-attention). Output: ``(batch, seq_len, embedding_dim)``.
2. **SentenceEmbedder** — aggregates token vectors into a sentence embedding. Output: ``(batch, embedding_dim)`` or ``(batch, num_classes, embedding_dim)`` with label attention.
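
The two stages can be sketched with plain PyTorch ops. This is an illustrative shape check, not the library's implementation; masked mean pooling is assumed as the aggregation:

.. code-block:: python

   import torch
   import torch.nn as nn

   batch, seq_len, embedding_dim = 2, 7, 128
   vocab_size, padding_idx = 5000, 0

   input_ids = torch.randint(1, vocab_size, (batch, seq_len))
   attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
   attention_mask[0, 5:] = 0  # pad out the tail of the first sequence

   # Stage 1: token embedding -> (batch, seq_len, embedding_dim)
   token_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)(input_ids)

   # Stage 2: masked mean pooling -> (batch, embedding_dim)
   mask = attention_mask.unsqueeze(-1).float()
   sentence_embedding = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

   assert token_embeddings.shape == (batch, seq_len, embedding_dim)
   assert sentence_embedding.shape == (batch, embedding_dim)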

TokenEmbedder
~~~~~~~~~~~~~

Embeds tokenized text with optional self-attention.

.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedder
   :members:
   :undoc-members:
   :show-inheritance:

TokenEmbedderConfig
~~~~~~~~~~~~~~~~~~~

Configuration for TokenEmbedder.

.. autoclass:: torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig
   :members:
   :undoc-members:
   :show-inheritance:

Example:

.. code-block:: python

   from torchTextClassifiers.model.components import (
       TokenEmbedder, TokenEmbedderConfig, AttentionConfig,
   )

   # Simple token embedder (no self-attention)
   config = TokenEmbedderConfig(
       vocab_size=5000,
       embedding_dim=128,
       padding_idx=0,
   )
   token_embedder = TokenEmbedder(config)
   out = token_embedder(input_ids, attention_mask)
   # out["token_embeddings"]: (batch, seq_len, 128)

   # With self-attention
   attention_config = AttentionConfig(
       n_embd=128,
       n_layers=2,
       n_head=4,
       n_kv_head=4,
       positional_encoding=False,
   )
   config = TokenEmbedderConfig(
       vocab_size=5000,
       embedding_dim=128,
       padding_idx=0,
       attention_config=attention_config,
   )
   token_embedder = TokenEmbedder(config)
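
With an ``attention_config`` the token stage mixes information across positions, but its output shape is unchanged. A stand-in sketch using ``torch.nn.MultiheadAttention`` (not the library's ``Attention`` module):

.. code-block:: python

   import torch
   import torch.nn as nn

   batch, seq_len, n_embd, n_head = 2, 7, 128, 4
   token_embeddings = torch.randn(batch, seq_len, n_embd)

   # Self-attention: queries, keys and values are all the token embeddings
   attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
   contextual, _ = attn(token_embeddings, token_embeddings, token_embeddings)

   assert contextual.shape == (batch, seq_len, n_embd)  # shape preserved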

SentenceEmbedder
~~~~~~~~~~~~~~~~

Aggregates per-token embeddings into a single sentence embedding.

.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedder
   :members:
   :undoc-members:
   :show-inheritance:

SentenceEmbedderConfig
~~~~~~~~~~~~~~~~~~~~~~

Configuration for SentenceEmbedder.

.. autoclass:: torchTextClassifiers.model.components.text_embedder.SentenceEmbedderConfig
   :members:
   :undoc-members:
   :show-inheritance:

LabelAttentionConfig
~~~~~~~~~~~~~~~~~~~~

Configuration for the label-attention aggregation mode.

.. autoclass:: torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig
   :members:
   :undoc-members:
   :show-inheritance:

Example:

.. code-block:: python

   from torchTextClassifiers.model.components import (
       SentenceEmbedder, SentenceEmbedderConfig,
       LabelAttentionConfig,
   )

   # token_embeddings and attention_mask come from the TokenEmbedder output

   # Mean-pooling (default)
   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
   out = sentence_embedder(token_embeddings, attention_mask)
   # out["sentence_embedding"]: (batch, 128)

   # Label attention — one embedding per class
   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(
       aggregation_method=None,
       label_attention_config=LabelAttentionConfig(
           n_head=4,
           num_classes=6,
           embedding_dim=128,
       ),
   ))
   out = sentence_embedder(token_embeddings, attention_mask)
   # out["sentence_embedding"]: (batch, num_classes, 128)
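
The label-attention mode can be pictured as each class holding a learned query that attends over the token embeddings. The sketch below shows the assumed mechanism and output shape, not the library's exact implementation:

.. code-block:: python

   import torch
   import torch.nn as nn

   batch, seq_len, embedding_dim, num_classes = 2, 7, 128, 6
   token_embeddings = torch.randn(batch, seq_len, embedding_dim)
   attention_mask = torch.ones(batch, seq_len)

   # One learned query vector per class
   class_queries = nn.Parameter(torch.randn(num_classes, embedding_dim))

   # Scaled dot-product scores: (batch, num_classes, seq_len)
   scores = torch.einsum("ce,bse->bcs", class_queries, token_embeddings)
   scores = scores / embedding_dim ** 0.5
   scores = scores.masked_fill(attention_mask[:, None, :] == 0, float("-inf"))
   weights = scores.softmax(dim=-1)

   # One pooled embedding per class: (batch, num_classes, embedding_dim)
   per_class = torch.einsum("bcs,bse->bce", weights, token_embeddings)
   assert per_class.shape == (batch, num_classes, embedding_dim)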

Categorical Features
--------------------

Components can be composed to create custom architectures:

.. code-block:: python

   import torch
   import torch.nn as nn
   from torchTextClassifiers.model.components import (
       TokenEmbedder, TokenEmbedderConfig,
       SentenceEmbedder, SentenceEmbedderConfig,
       CategoricalVariableNet, ClassificationHead,
   )

   class CustomModel(nn.Module):
       def __init__(self):
           super().__init__()
           self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
               vocab_size=5000, embedding_dim=128, padding_idx=0,
           ))
           self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig())
           self.cat_net = CategoricalVariableNet(...)
           self.head = ClassificationHead(...)

       def forward(self, input_ids, attention_mask, categorical_data):
           token_out = self.token_embedder(input_ids, attention_mask)
           sent_out = self.sentence_embedder(
               token_out["token_embeddings"], token_out["attention_mask"]
           )
           cat_features = self.cat_net(categorical_data)
           combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1)
           return self.head(combined)
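
To sanity-check the shape bookkeeping of this pattern, here is a runnable miniature with the library components replaced by plain ``torch.nn`` stand-ins (all dimensions here are hypothetical):

.. code-block:: python

   import torch
   import torch.nn as nn

   class TinyModel(nn.Module):
       def __init__(self, vocab_size=5000, text_dim=128, n_cats=10, cat_dim=8, num_classes=5):
           super().__init__()
           self.tok = nn.Embedding(vocab_size, text_dim, padding_idx=0)
           self.cat = nn.Embedding(n_cats, cat_dim)  # stand-in for CategoricalVariableNet
           self.head = nn.Linear(text_dim + cat_dim, num_classes)  # stand-in for ClassificationHead

       def forward(self, input_ids, attention_mask, categorical_data):
           mask = attention_mask.unsqueeze(-1).float()
           sent = (self.tok(input_ids) * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
           cat = self.cat(categorical_data).mean(1)  # average the categorical embeddings
           return self.head(torch.cat([sent, cat], dim=1))

   model = TinyModel()
   logits = model(
       torch.randint(1, 5000, (4, 12)),      # input_ids
       torch.ones(4, 12, dtype=torch.long),  # attention_mask
       torch.randint(0, 10, (4, 2)),         # two categorical variables
   )
   assert logits.shape == (4, 5)  # (batch, num_classes)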

See Also
--------
* :doc:`model` - How components are used in models
* :doc:`../architecture/overview` - Architecture explanation
* :doc:`configs` - ModelConfig for component configuration

5 changes: 3 additions & 2 deletions docs/source/api/index.rst

The API is organized into several modules:
* :doc:`wrapper` - High-level torchTextClassifiers wrapper class
* :doc:`configs` - Configuration classes (ModelConfig, TrainingConfig)
* :doc:`tokenizers` - Text tokenization (NGram, WordPiece, HuggingFace)
* :doc:`components` - Model components (TokenEmbedder, SentenceEmbedder, CategoricalVariableNet, etc.)
* :doc:`model` - Core PyTorch models
* :doc:`dataset` - Dataset classes for data loading


Most Used Classes
-----------------
Architecture Components
~~~~~~~~~~~~~~~~~~~~~~~

* :class:`torchTextClassifiers.model.components.text_embedder.TokenEmbedder` - Token embedding layer
* :class:`torchTextClassifiers.model.components.text_embedder.SentenceEmbedder` - Sentence aggregation layer
* :class:`torchTextClassifiers.model.components.CategoricalVariableNet` - Categorical features
* :class:`torchTextClassifiers.model.components.ClassificationHead` - Classification layer
* :class:`torchTextClassifiers.model.components.Attention.AttentionConfig` - Attention configuration
36 changes: 20 additions & 16 deletions docs/source/api/model.rst

Core PyTorch nn.Module combining all components.

**Architecture:**

The model combines four main components:

1. **TokenEmbedder**: Maps each token to a dense vector (with optional self-attention)
2. **SentenceEmbedder**: Aggregates token vectors into a sentence representation
3. **CategoricalVariableNet** (optional): Handles categorical features
4. **ClassificationHead**: Produces class logits

Example:

.. code-block:: python

   from torchTextClassifiers.model import TextClassificationModel
   from torchTextClassifiers.model.components import (
       TokenEmbedder, TokenEmbedderConfig,
       SentenceEmbedder, SentenceEmbedderConfig,
       CategoricalVariableNet,
       ClassificationHead,
   )

   # Create components
   token_embedder = TokenEmbedder(TokenEmbedderConfig(
       vocab_size=5000,
       embedding_dim=128,
       padding_idx=0,
   ))
   sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))

   cat_net = CategoricalVariableNet(
       categorical_vocabulary_sizes=[10, 20],
       categorical_embedding_dims=[8, 16],
   )

   classification_head = ClassificationHead(
       input_dim=128 + 24,  # text_dim + cat_dim
       num_classes=5,
   )

   # Combine into model
   model = TextClassificationModel(
       token_embedder=token_embedder,
       sentence_embedder=sentence_embedder,
       categorical_variable_net=cat_net,
       classification_head=classification_head,
   )

   # Forward pass
   logits = model(input_ids, attention_mask, categorical_data)

PyTorch Lightning Module
-------------------------