
Commit 4f392ac

Update nn.aggr documentation (#5099)
* initial commit
* changelog
* update
1 parent 2efd97a commit 4f392ac

10 files changed: +42 -56 lines changed


CHANGELOG.md (+1 -1)

@@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
 - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
-- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085), [#5097](https://github.com/pyg-team/pytorch_geometric/pull/5097))
+- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085), [#5097](https://github.com/pyg-team/pytorch_geometric/pull/5097), [#5099](https://github.com/pyg-team/pytorch_geometric/pull/5099))
 - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
 - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
 - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))

docs/source/modules/nn.rst (+2 -3)

@@ -25,9 +25,8 @@ Convolutional Layers
     {{ cls }}
 {% endfor %}

-.. automodule:: torch_geometric.nn.conv.message_passing
+.. autoclass:: torch_geometric.nn.conv.MessagePassing
    :members:
-   :exclude-members: training

 .. automodule:: torch_geometric.nn.conv
    :members:
@@ -49,7 +48,7 @@ Aggregation Operators
     {{ cls }}
 {% endfor %}

-.. automodule:: torch_geometric.nn.aggr.base
+.. autoclass:: torch_geometric.nn.aggr.Aggregation
    :members:

 .. automodule:: torch_geometric.nn.aggr

examples/equilibrium_median.py (+1 -1)

@@ -11,7 +11,7 @@
 import numpy as np
 import torch

-from torch_geometric.nn.aggr import EquilibriumAggregation
+from torch_geometric.nn import EquilibriumAggregation

 input_size = 100
 steps = 10000000

test/nn/aggr/test_equilibrium.py (-2)

@@ -7,7 +7,6 @@
 @pytest.mark.parametrize('iter', [0, 1, 5])
 @pytest.mark.parametrize('alpha', [0, .1, 5])
 def test_equilibrium(iter, alpha):
-
     batch_size = 10
     feature_channels = 3
     output_channels = 2
@@ -30,7+29,6 @@ def test_equilibrium(iter, alpha):
 @pytest.mark.parametrize('iter', [0, 1, 5])
 @pytest.mark.parametrize('alpha', [0, .1, 5])
 def test_equilibrium_batch(iter, alpha):
-
     batch_1, batch_2 = 4, 6
     feature_channels = 3
     output_channels = 2

torch_geometric/nn/aggr/attention.py (-8)

@@ -32,14 +32,6 @@ class AttentionalAggregation(Aggregation):
             shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`
             before combining them with the attention scores, *e.g.*, defined by
             :class:`torch.nn.Sequential`. (default: :obj:`None`)
-
-    Shapes:
-        - **input:**
-          node features :math:`(|\mathcal{V}|, F)`,
-          batch vector :math:`(|\mathcal{V}|)` *(optional)*
-        - **output:**
-          graph features :math:`(|\mathcal{G}|, F)` where
-          :math:`|\mathcal{G}|` denotes the number of graphs in the batch
     """
     def __init__(self, gate_nn: torch.nn.Module,
                  nn: Optional[torch.nn.Module] = None):
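
As context for the `gate_nn` / `nn` arguments documented above, a minimal usage sketch (layer and tensor sizes are illustrative assumptions; the call follows the shared `Aggregation` interface of this package):

import torch
from torch_geometric.nn.aggr import AttentionalAggregation

# gate_nn maps [-1, in_channels] -> [-1, 1] attention scores;
# nn (optional) transforms node features before they are combined.
gate_nn = torch.nn.Sequential(torch.nn.Linear(64, 1))
nn = torch.nn.Sequential(torch.nn.Linear(64, 64))
aggr = AttentionalAggregation(gate_nn, nn)

x = torch.randn(6, 64)                    # 6 nodes, 64 features each
index = torch.tensor([0, 0, 1, 1, 1, 2])  # node-to-graph assignment
out = aggr(x, index)                      # -> shape [3, 64], one row per graph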

torch_geometric/nn/aggr/base.py (+10 -1)

@@ -8,7 +8,16 @@


 class Aggregation(torch.nn.Module):
-    r"""An abstract base class for implementing custom aggregations."""
+    r"""An abstract base class for implementing custom aggregations.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})` or edge features
+          :math:`(|\mathcal{E}|, F_{in})`,
+          index vector :math:`(|\mathcal{V}|)` or :math:`(|\mathcal{E}|)`,
+        - **output:** graph features :math:`(|\mathcal{G}|, F_{out})` or node
+          features :math:`(|\mathcal{V}|, F_{out})`
+    """

     # @abstractmethod
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
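
A toy sketch of the input/output contract added to the docstring above, via a custom subclass (assuming the `reduce` scatter helper exposed by the base class in recent PyG versions; the class name is illustrative):

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation

class MyMeanAggregation(Aggregation):
    # Averages all feature rows that share the same index entry.
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')

x = torch.randn(5, 16)                 # 5 nodes/edges with 16 features
index = torch.tensor([0, 0, 1, 1, 1])  # maps each row to an output group
out = MyMeanAggregation()(x, index)    # -> shape [2, 16]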

torch_geometric/nn/aggr/equilibrium.py (+22 -27)

@@ -102,31 +102,30 @@ def forward(


 class EquilibriumAggregation(Aggregation):
-    r"""
-    The graph global pooling layer from the
-    `"Equilibrium Aggregation: Encoding Sets via Optimization"
-    <https://arxiv.org/abs/2202.12795>`_ paper.
-    This output of this layer :math:`\mathbf{y}` is defined implicitly by
-    defining a potential function :math:`F(\mathbf{x}, \mathbf{y})`
-    and regulatization function :math:`R(\mathbf{y})` and the condition
+    r"""The equilibrium aggregation layer from the `"Equilibrium Aggregation:
+    Encoding Sets via Optimization" <https://arxiv.org/abs/2202.12795>`_ paper.
+    The output of this layer :math:`\mathbf{y}` is defined implicitly via a
+    potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization term
+    :math:`R(\mathbf{y})`, and the condition

     .. math::
-        \mathbf{y} = \min_\mathbf{y} R(\mathbf{y}) +
-            \sum_{i} F(\mathbf{x}_i, \mathbf{y})
+        \mathbf{y} = \min_\mathbf{y} R(\mathbf{y}) + \sum_{i}
+        F(\mathbf{x}_i, \mathbf{y}).

-    This implementation uses a ResNet Like model for the potential function
-    and a simple L2 norm for the regularizer with learnable weight
-    :math:`\lambda`.
+    The given implementation uses a ResNet-like model for the potential
+    function and a simple :math:`L_2` norm :math:`R(\mathbf{y}) =
+    \textrm{softplus}(\lambda) \cdot {\| \mathbf{y} \|}^2_2` for the
+    regularizer with learnable weight :math:`\lambda`.

     Args:
-        in_channels (int): The number of channels in the input to the layer.
-        out_channels (float): The number of channels in the ouput.
-        num_layers (List[int): A list of the number of hidden units in the
-            potential function.
+        in_channels (int): Size of each input sample.
+        out_channels (int): Size of each output sample.
+        num_layers (List[int): List of hidden channels in the potential
+            function.
         grad_iter (int): The number of steps to take in the internal gradient
             descent. (default: :obj:`5`)
-        lamb (float): The initial regularization constant. Is learnable.
-            descent. (default: :obj:`0.1`)
+        lamb (float): The initial regularization constant.
+            (default: :obj:`0.1`)
     """
     def __init__(self, in_channels: int, out_channels: int,
                  num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):
@@ -135,29 +134,25 @@ def __init__(self, in_channels: int, out_channels: int,
         self.potential = ResNetPotential(in_channels + out_channels, 1,
                                          num_layers)
         self.optimizer = MomentumOptimizer()
-        self._initial_lambda = lamb
-        self._labmda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
+        self.initial_lamb = lamb
+        self.lamb = torch.nn.Parameter(Tensor(1), requires_grad=True)
         self.softplus = torch.nn.Softplus()
         self.grad_iter = grad_iter
         self.output_dim = out_channels
         self.reset_parameters()

     def reset_parameters(self):
-        self.lamb.data.fill_(self._initial_lambda)
+        self.lamb.data.fill_(self.initial_lamb)
         reset(self.optimizer)
         reset(self.potential)

-    @property
-    def lamb(self):
-        return self.softplus(self._labmda)
-
     def init_output(self, index: Optional[Tensor] = None) -> Tensor:
         index_size = 1 if index is None else int(index.max().item() + 1)
         return torch.zeros(index_size, self.output_dim,
                            requires_grad=True).float()

-    def reg(self, y: Tensor) -> float:
-        return self.lamb * y.square().mean(dim=-2).sum(dim=0)
+    def reg(self, y: Tensor) -> Tensor:
+        return self.softplus(self.lamb) * y.square().sum(dim=-1).mean()

     def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]):
         return self.potential(x, y, index) + self.reg(y)
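
A hedged usage sketch for `EquilibriumAggregation` using the constructor arguments documented above (channel sizes, hidden layers, and the two-graph index are illustrative; the top-level import follows the example change above, and exact runtime behavior may vary across PyG versions):

import torch
from torch_geometric.nn import EquilibriumAggregation

# 4 input features per element, 2-dimensional aggregated output,
# potential network with two hidden layers of 16 units each.
aggr = EquilibriumAggregation(in_channels=4, out_channels=2,
                              num_layers=[16, 16], grad_iter=5, lamb=0.1)

x = torch.randn(10, 4)                   # 10 set elements / nodes
index = torch.tensor([0] * 6 + [1] * 4)  # two graphs with 6 and 4 elements
y = aggr(x, index)                       # -> shape [2, 2], found by inner optimization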

torch_geometric/nn/aggr/gmt.py (-8)

@@ -171,14 +171,6 @@ class GraphMultisetTransformer(Aggregation):
             (default: :obj:`4`)
         layer_norm (bool, optional): If set to :obj:`True`, will make use of
             layer normalization. (default: :obj:`False`)
-
-    Shapes:
-        - **input:**
-          node features :math:`(|\mathcal{V}|, F_{in})`,
-          batch vector :math:`(|\mathcal{V}|)`,
-          edge indices :math:`(2, |\mathcal{E}|)` *(optional)*
-        - **output:** graph features :math:`(|\mathcal{G}|, F_{out})` where
-          :math:`|\mathcal{G}|` denotes the number of graphs in the batch
     """
     def __init__(
         self,

torch_geometric/nn/aggr/lstm.py (+2 -4)

@@ -8,14 +8,12 @@

 class LSTMAggregation(Aggregation):
     r"""Performs LSTM-style aggregation in which the elements to aggregate are
-    interpreted as a sequence.
+    interpreted as a sequence, as described in the `"Inductive Representation
+    Learning on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.

     .. warning::
         :class:`LSTMAggregation` is not a permutation-invariant operator.

-    .. note::
-        :class:`LSTMAggregation` requires sorted indices as input.
-
     Args:
         in_channels (int): Size of each input sample.
         out_channels (int): Size of each output sample.
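
A usage sketch for `LSTMAggregation` (sizes are illustrative; the index is kept sorted here since elements sharing an index are fed to the LSTM as a sequence, and ordering therefore matters):

import torch
from torch_geometric.nn.aggr import LSTMAggregation

aggr = LSTMAggregation(in_channels=8, out_channels=32)

x = torch.randn(7, 8)                        # 7 nodes with 8 features
index = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # sorted node-to-graph assignment
out = aggr(x, index)                         # -> shape [3, 32]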

torch_geometric/nn/aggr/multi.py (+4 -1)

@@ -10,7 +10,10 @@

 class MultiAggregation(Aggregation):
     r"""Performs aggregations with one or more aggregators and combines
-    aggregated results.
+    aggregated results, as described in the `"Principal Neighbourhood
+    Aggregation for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ and
+    `"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions"
+    <https://arxiv.org/abs/2104.01481>`_ papers.

     Args:
         aggrs (list): The list of aggregation schemes to use.
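
A usage sketch for `MultiAggregation` with explicit aggregator modules passed via `aggrs` (sizes are illustrative; in this version the per-aggregator outputs are combined by concatenation along the feature dimension):

import torch
from torch_geometric.nn.aggr import (MaxAggregation, MeanAggregation,
                                     MultiAggregation, SumAggregation)

aggr = MultiAggregation([MeanAggregation(), MaxAggregation(), SumAggregation()])

x = torch.randn(5, 16)              # 5 nodes with 16 features
index = torch.tensor([0, 0, 0, 1, 1])
out = aggr(x, index)                # -> shape [2, 48]: 16 features per aggregator, concatenated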
