Skip to content

Enums are not represented correctly in the generated client stub

Consider the following file `test.py`, which defines a function `test_func` that takes an `Enum` value and uses an `Enum` member as its default:

from demessaging import main, registry
from enum import Enum

__all__ = ["test_func"]


# Enum used both as a parameter/return annotation and as a default value of
# test_func. It is registered via registry.register_type (presumably so the
# generator knows to emit it into the client stub; the quoted stubs in this
# issue do contain a copy of this class body).
@registry.register_type
class TestEnum(str, Enum):

    TEST_VALUE = "test"


# Returns its argument unchanged. The enum-member default is what triggers
# the bug: the generator writes the default into the stub using its repr(),
# which for a plain Enum member is `<TestEnum.TEST_VALUE: 'test'>` -- not a
# valid Python expression. (Kept as a `#` comment rather than a docstring,
# because the generator copies the function docstring into the stub.)
def test_func(param: TestEnum = TestEnum.TEST_VALUE) -> TestEnum:
    return param


if __name__ == "__main__":
    # CLI entry point; e.g. `python test.py generate > test_api.py` prints the
    # generated client stub. The messaging topic is supplied as config here.
    main(messaging_config=dict(topic="test-topic"))

If you generate the client stub via

python test.py generate > test_api.py

you will get the following error:
Traceback (most recent call last):
  File "test.py", line 18, in <module>
    main(topic="test-topic")
  File "/home/psommer/Documents/code/development/datahub/de-messaging-python/demessaging/backend/__init__.py", line 83, in main
    print(method(**method_kws))
  File "/home/psommer/Documents/code/development/datahub/de-messaging-python/demessaging/backend/module.py", line 378, in generate
    code = black.format_str(
  File "src/black/__init__.py", line 1154, in format_str
  File "src/black/__init__.py", line 1164, in _format_str_once
  File "src/black/parsing.py", line 128, in lib2to3_parse
black.parsing.InvalidInput: Cannot parse: 31:41: def test_func(param: __main__.TestEnum = <TestEnum.TEST_VALUE: 'test'>) -> __main__.TestEnum:

This is because the default value `TestEnum.TEST_VALUE` is rendered as `<TestEnum.TEST_VALUE: 'test'>`, which is the `__repr__` of `TestEnum.TEST_VALUE` and is not valid Python syntax.

You can see this by disabling the formatters when generating the client stub:

python test.py generate --no-formatters > test_api.py

which produces the following client stub:
"""

"""
from typing import Callable, Union
NoneType = type(None)

from demessaging.config import ModuleConfig
from demessaging import configure, main, BackendModule as _BackendModule




from demessaging import main
from demessaging import registry
from enum import Enum

__all__ = [
    "test_func"
]


class TestEnum(str, Enum):

    TEST_VALUE = "test"





def test_func(param: TestEnum = <TestEnum.TEST_VALUE: 'test'>) -> TestEnum:
    """
    
    """
    request = {
        "func_name": "test_func",
        "param": param
    }

    model = BackendModule.parse_obj(request)
    response = model.compute()

    return response.__root__  # type: ignore


backend_config = ModuleConfig.parse_raw("""
{
    "messaging_config": {
        "topic": "test-topic",
        "max_workers": null,
        "queue_size": null,
        "max_payload_size": 512000,
        "host": "localhost",
        "port": "8080",
        "persistent": "non-persistent",
        "tenant": "public",
        "namespace": "default"
}
}
""")

_creator: Callable
if __name__ == "__main__":
    _creator = main
else:
    _creator = _BackendModule.create_model

BackendModule = _creator(
    __name__,
    config=backend_config,
    class_name="BackendModule",
    members=[
        test_func
    ]
)

One way to circumvent this is to override the `__repr__` method of the `Enum` subclass, e.g.:

# Workaround: override __repr__ so that repr(TestEnum.TEST_VALUE) is the
# valid Python expression "TestEnum.TEST_VALUE", which the stub generator can
# then emit verbatim as a default value. Note that this changes repr() for
# every use of the enum (logging, debugging, error messages), not only for
# stub generation.
@registry.register_type
class TestEnum(str, Enum):

    TEST_VALUE = "test"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

This generates the desired client stub:
"""

"""
from enum import Enum
from typing import Callable

from demessaging import BackendModule as _BackendModule
from demessaging import main
from demessaging.config import ModuleConfig

NoneType = type(None)


__all__ = ["test_func"]


class TestEnum(str, Enum):

    TEST_VALUE = "test"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"


def test_func(
    param: TestEnum = TestEnum.TEST_VALUE,
) -> TestEnum:
    """ """
    request = {"func_name": "test_func", "param": param}

    model = BackendModule.parse_obj(request)
    response = model.compute()

    return response.__root__  # type: ignore


backend_config = ModuleConfig.parse_raw(
    """
{
    "messaging_config": {
        "topic": "test-topic",
        "max_workers": null,
        "queue_size": null,
        "max_payload_size": 512000,
        "host": "localhost",
        "port": "8080",
        "persistent": "non-persistent",
        "tenant": "public",
        "namespace": "default"
}
}
"""
)

_creator: Callable
if __name__ == "__main__":
    _creator = main
else:
    _creator = _BackendModule.create_model

BackendModule = _creator(
    __name__,
    config=backend_config,
    class_name="BackendModule",
    members=[test_func],
)