Skip to content

Commit

Permalink
Teach is_signature_compatible() to dig into similar annotations
Browse files Browse the repository at this point in the history
Summary: D68450007 updated some annotations in pytorch. This function wasn't correctly evaluating `typing.Dict[X, Y]` and `dict[X, Y]` as the equivalent.

Differential Revision: D68475380
  • Loading branch information
aorenste authored and facebook-github-bot committed Jan 22, 2025
1 parent dd5457c commit d8a6f7a
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion torchrec/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@
# pyre-strict

import inspect
import typing
from typing import Any


def _is_annot_compatible(prev: Any, curr: Any) -> bool:
if prev == curr:
return True

if not (prev_origin := typing.get_origin(prev)):
return False
if not (curr_origin := typing.get_origin(curr)):
return False

if prev_origin != curr_origin:
return False

prev_args = typing.get_args(prev)
curr_args = typing.get_args(curr)
if len(prev_args) != len(curr_args):
return False

for prev_arg, curr_arg in zip(prev_args, curr_args):
if not _is_annot_compatible(prev_arg, curr_arg):
return False

return True


def is_signature_compatible(
Expand Down Expand Up @@ -84,6 +110,8 @@ def is_signature_compatible(
return False

# TODO: Account for Union Types?
if current_signature.return_annotation != previous_signature.return_annotation:
if not _is_annot_compatible(
previous_signature.return_annotation, current_signature.return_annotation
):
return False
return True

0 comments on commit d8a6f7a

Please sign in to comment.