Brak opisu

schema_to_dataclass.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #!/usr/bin/env python
  2. import json
  3. import re
  4. import sys
  5. from pathlib import Path
  6. map_schema_type_to_python = {
  7. "object": "Dict",
  8. "array": "List",
  9. "integer": "int",
  10. "string": "str",
  11. "number": "int",
  12. "boolean": "bool",
  13. "any": "Any",
  14. }
  15. def create_dataclass(name):
  16. return dataclass(name)
  17. def create_attribute(name, type, required):
  18. name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
  19. type = map_schema_type_to_python[type]
  20. return attribute(name, type, required)
  21. class dataclass:
  22. def __init__(self, name):
  23. self.name = name
  24. self.attrs = []
  25. def add_attr(self, attr):
  26. self.attrs.append(attr)
  27. def __str__(self):
  28. output = f"@dataclass\nclass {self.name}Payload:\n"
  29. if len(self.attrs) == 0:
  30. return output + " pass\n"
  31. optional_attrs = ""
  32. for attr in self.attrs:
  33. if attr.required:
  34. output += str(attr)
  35. else:
  36. optional_attrs += str(attr)
  37. return output + optional_attrs
  38. class attribute:
  39. def __init__(self, name, type, required):
  40. self.name = name
  41. self.type = type
  42. self.required = required
  43. def __str__(self):
  44. name = self.name
  45. if not re.match("^[a-zA-Z_]", self.name):
  46. name = "_" + self.name
  47. definition = f" {name}: {self.type}"
  48. if self.required is True:
  49. definition += "\n"
  50. else:
  51. definition += " = None\n"
  52. return definition
  53. def __repr__(self):
  54. return f"<{self.name}, {self.type}, {self.required}> "
  55. calls = []
  56. call_results = []
  57. def parse_schema(schema):
  58. with open(schema, "r") as f:
  59. schema = json.loads(f.read())
  60. name = schema["$id"].split(":")[-1]
  61. call = False
  62. call_result = False
  63. if name.endswith("Request"):
  64. call = True
  65. name = name[: -len("Request")]
  66. elif name.endswith("Response"):
  67. call_result = True
  68. name = name[: -len("Response")]
  69. dc = create_dataclass(name)
  70. try:
  71. properties = schema["properties"]
  72. except KeyError:
  73. if call:
  74. calls.append(dc)
  75. elif call_result:
  76. call_results.append(dc)
  77. return
  78. for property, definition in properties.items():
  79. if property == "customData":
  80. continue
  81. required = True
  82. try:
  83. required = property in schema["required"]
  84. except KeyError:
  85. required = False
  86. try:
  87. type = definition["type"]
  88. except KeyError:
  89. try:
  90. ref = definition["$ref"].split("/")[-1]
  91. type = schema["definitions"][ref]["type"]
  92. except KeyError:
  93. if definition == {}:
  94. type = "any"
  95. attr = create_attribute(property, type, required)
  96. dc.add_attr(attr)
  97. if call:
  98. calls.append(dc)
  99. elif call_result:
  100. call_results.append(dc)
  101. if __name__ == "__main__":
  102. if len(sys.argv) != 2:
  103. print("Pass path to folder with schemas")
  104. sys.exit(-1)
  105. p = Path(sys.argv[1])
  106. schemas = list(p.glob("*.json"))
  107. for schema in schemas:
  108. parse_schema(schema)
  109. with open("call.py", "wb+") as f:
  110. f.write(b"from typing import Any, Dict, List\n")
  111. f.write(b"from dataclasses import dataclass, field, Optional\n")
  112. for call in sorted(calls, key=lambda call: call.name):
  113. f.write(b"\n\n")
  114. f.write(str(call).encode("utf-8"))
  115. with open("call_result.py", "wb+") as f:
  116. f.write(b"from typing import Any, Dict, List\n")
  117. f.write(b"from dataclasses import dataclass, field\n")
  118. for call in sorted(call_results, key=lambda call: call.name):
  119. f.write(b"\n\n")
  120. f.write(str(call).encode("utf-8"))