
nn.Module subclasses don't typecheck due to lack of contravariant input types #35566

Closed

Description

@pmeier

❓ Questions and Help

I've asked about this in the discussion forum, but have gotten no answer so far. Although it hasn't been up for that long, I don't think other PyTorch users will be able to help me out, so I'm asking here. If you disagree, please tell me how to proceed.


I'm trying to annotate subclasses of nn.Module inline in Python 3.6, but so far I have been unable to get it to work. For my project I create an abstract subclass of nn.Module:

from typing import Any
from torch import nn


class Foo(nn.Module):
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        pass

Running mypy on this succeeds. In a second step, I add a more concrete class:

import torch


class Bar(Foo):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

Now mypy errors with

error: Signature of "forward" incompatible with supertype "Foo"
error: Signature of "forward" incompatible with supertype "Module"

while pointing to forward() of Bar both times. I'm aware that this is a valid error, since Bar violates the Liskov Substitution Principle. I don't think every subclass of nn.Module is intended to have that exact signature, but I'm puzzled as to how to get around this.
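The only thing I've found so far that makes mypy accept the narrowed signature is to silence the override check explicitly. This is just a sketch of that workaround; it suppresses the error rather than fixing the missing contravariance:

from typing import Any

import torch
from torch import nn


class Foo(nn.Module):
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        pass


class Bar(Foo):
    # Workaround sketch: explicitly silence mypy's override check for the
    # narrowed signature. Both errors are reported on this line, so a single
    # ignore comment suppresses them.
    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return input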

Looking at the stub of nn.Module, I found this comment:

# We parameter modules by the return type of its `forward` (and therefore `__call__`) method. This allows
# type inference to infer that the return value of calling a module in the canonical way (via `__call__)` is the
# same as the custom `forward` function of the submodule. Submodules that wish to opt in this functionality be
# defined as eg class ReturnsTwoTensors(Module[Tuple[Tensor, Tensor]]): ...

I think this paragraph is related, but so far I have been unable to make sense of it. Can you provide a minimal example of how to do this?
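For reference, my best guess at what the comment means is something like the sketch below. I'm assuming the stubs make Module generic in the return type of forward, and since nn.Module is not subscriptable at runtime in 1.4 I've guarded the subscript behind TYPE_CHECKING. Even if this guess is right, it only parameterizes the return type, so it doesn't seem to address the incompatible input types above.

from typing import TYPE_CHECKING, Any, Tuple

import torch
from torch import nn

if TYPE_CHECKING:
    # Assumption based on the quoted comment: the stubs make Module generic in
    # the return type of forward/__call__. The subscript exists only in the
    # stubs, so it is guarded here; at runtime nn.Module is not subscriptable.
    TwoTensorModule = nn.Module[Tuple[torch.Tensor, torch.Tensor]]
else:
    TwoTensorModule = nn.Module


class ReturnsTwoTensors(TwoTensorModule):
    def forward(self, *input: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        t = torch.zeros(1)
        return t, t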

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: Could not collect

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1080
Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.2
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect

Extended Environment

Package           Version
----------------- -------
mypy              0.770  
mypy-extensions   0.4.3  
typed-ast         1.4.1  
typing-extensions 3.7.4.1

cc @ezyang


Labels

enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix), module: typing (Related to mypy type annotations), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
