Skip to content
Snippets Groups Projects
Commit d0a6adf8 authored by Tomáš Orviský's avatar Tomáš Orviský
Browse files

Changed code formatting, snake_case, type safety and minor refactorings

parent 2912e458
No related merge requests found
from manim import *
from typing import Type
from pyPlantUML import *
import argparse
from pyPlantUML.layout.KamadaKawai import KamadaKawai
from pyPlantUML.layout.DotLayout import DotLayout
from pyPlantUML.layout.SpringLayout import SpringLayout
class MainScene(MovingCameraScene):
file: str
animate: bool
def construct(self):
parser = PUMLParser()
diagram: Type[Diagram] = parser.parseFile(self.file)
diagram: Diagram = parser.parse_file(self.file)
layout = SpringLayout(diagram)
layout.apply()
......@@ -21,16 +22,16 @@ class MainScene(MovingCameraScene):
self.camera.background_color = WHITE
Text.set_default(font_size=16)
diagram.setScene(self)
diagram.set_scene(self)
diagram.animate = self.animate
diagram.draw()
def setFile(self, file: str):
def set_file(self, file: str):
self.file = file
def setAnimate(self, animate: bool):
def set_animate(self, animate: bool):
self.animate = animate
......@@ -50,7 +51,7 @@ if __name__ == "__main__":
scene = MainScene()
scene.setFile(args.file)
scene.setAnimate(args.animate)
scene.set_file(args.file)
scene.set_animate(args.animate)
scene.render()
......@@ -8,11 +8,12 @@ class AttributeModifier(Enum):
PROTECTED = "#"
PUBLIC = "+"
def fromString(string: str):
@staticmethod
def from_string(string: str):
attrMap = {member.value: member for member in AttributeModifier}
attr_map = {member.value: member for member in AttributeModifier}
result = attrMap.get(string)
result = attr_map.get(string)
if result is None:
result = AttributeModifier.NONE
......@@ -22,7 +23,7 @@ class AttributeModifier(Enum):
class ClassAttribute:
def __init__(self, isMethod: bool, modifier: AttributeModifier, text: str):
self.isMethod = isMethod
def __init__(self, is_method: bool, modifier: AttributeModifier, text: str):
self.isMethod = is_method
self.modifier = modifier
self.text = text
......@@ -2,9 +2,11 @@ from .DiagramObject import DiagramObject
from manim import *
class Diagram():
class Diagram:
def __init__(self, name: str):
self.scene = None
self.name = name
self.objects: typing.Dict[str, DiagramObject] = {}
self.animate = False
......@@ -12,36 +14,36 @@ class Diagram():
def draw(self):
for name, obj in self.objects.items():
self.drawObject(obj)
self.draw_object(obj)
for name, obj in self.objects.items():
for i, edge in enumerate(obj.edges):
if self.animate:
self.scene.play(Create(edge.draw(obj)))
self.scene.play(Create(edge.draw()))
else:
self.scene.add(edge.draw(obj))
self.scene.add(edge.draw())
def drawObject(self, obj: DiagramObject):
def draw_object(self, obj: DiagramObject):
if obj.mobject is None:
mobj = obj.draw()
mobject = obj.draw()
mobj.to_edge(UP)
mobject.to_edge(UP)
mobj.shift(RIGHT * obj.x)
mobj.shift(DOWN * obj.y)
mobject.shift(RIGHT * obj.x)
mobject.shift(DOWN * obj.y)
if self.animate:
self.scene.play(Create(mobj))
self.scene.play(Create(mobject))
self.scene.play(self.scene.camera.auto_zoom(
self.scene.mobjects, margin=0.5))
else:
self.scene.add(mobj)
self.scene.add(mobject)
self.scene.camera.auto_zoom(
self.scene.mobjects, margin=0.5, animate=False)
def addObject(self, obj: DiagramObject):
def add_object(self, obj: DiagramObject):
if obj.name not in self.objects:
self.objects[obj.name] = obj
......@@ -50,18 +52,14 @@ class Diagram():
if edge.target.name in self.objects:
edge.target = self.objects[edge.target.name]
else:
self.addObject(edge.target)
self.add_object(edge.target)
else:
for edge in obj.edges:
if edge.target not in self.objects:
self.addObject(edge.target)
self.add_object(edge.target)
edge.target = self.objects[edge.target.name]
self.objects[obj.name].edges += obj.edges
def rangeAroundZero(self, n):
half_n = n // 2
return list(range(-half_n, half_n + 1))
def setScene(self, scene: Scene):
def set_scene(self, scene: Scene):
self.scene = scene
......@@ -2,16 +2,14 @@ from .DiagramObject import DiagramObject
from .DiagramEdge import DiagramEdge
from .ClassAttribute import ClassAttribute
from typing import List
from manim import *
class DiagramClass(DiagramObject):
def __init__(self, name: str, type: str):
def __init__(self, name: str, class_type: str):
super().__init__(name)
self.type = type
self.type = class_type
self.edges: List[DiagramEdge] = []
self.attributes: List[ClassAttribute] = []
self.methods: List[ClassAttribute] = []
......@@ -21,57 +19,57 @@ class DiagramClass(DiagramObject):
header = Rectangle(color=GRAY)
text = Text(self.name, color=BLACK)
header.surround(text)
headGroup = VGroup(header, text)
head_group = VGroup(header, text)
attrBody = Rectangle(color=GRAY, height=0.2, width=0.2)
attrGroup = VGroup(attrBody)
attr_body = Rectangle(color=GRAY, height=0.2, width=0.2)
attr_group = VGroup(attr_body)
if len(self.attributes) != 0:
attrs = VGroup()
for attr in self.attributes:
textGroup = VGroup(
text_group = VGroup(
Text(attr.modifier.value, color=BLACK).scale(0.75),
Text(attr.text, color=BLACK).scale(0.75)
)
textGroup.arrange(RIGHT, buff=0.)
attrs.add(textGroup)
text_group.arrange(RIGHT, buff=0.)
attrs.add(text_group)
attrs.arrange(DOWN, buff=0.1)
attrGroup.add(attrs)
attrBody.surround(attrs, buff=0.2)
attrBody.stretch_to_fit_height(attrs.height + 0.1)
attr_group.add(attrs)
attr_body.surround(attrs, buff=0.2)
attr_body.stretch_to_fit_height(attrs.height + 0.1)
methodBody = Rectangle(color=GRAY, height=0.2, width=0.2)
methodGroup = VGroup(methodBody)
method_body = Rectangle(color=GRAY, height=0.2, width=0.2)
method_group = VGroup(method_body)
if len(self.methods) != 0:
methods = VGroup()
for method in self.methods:
textGroup = VGroup(
text_group = VGroup(
Text(method.modifier.value, color=BLACK).scale(0.75),
Text(method.text, color=BLACK).scale(0.75)
)
textGroup.arrange(RIGHT, buff=0.1)
methods.add(textGroup)
text_group.arrange(RIGHT, buff=0.1)
methods.add(text_group)
methods.arrange(DOWN, buff=0.1)
methodGroup.add(methods)
methodBody.surround(methods, buff=0.2)
methodBody.stretch_to_fit_height(methods.height + 0.1)
method_group.add(methods)
method_body.surround(methods, buff=0.2)
method_body.stretch_to_fit_height(methods.height + 0.1)
maxWidth = max(headGroup.width, attrGroup.width, methodGroup.width)
max_width = max(head_group.width, attr_group.width, method_group.width)
header.stretch_to_fit_width(maxWidth)
attrBody.stretch_to_fit_width(maxWidth)
methodBody.stretch_to_fit_width(maxWidth)
header.stretch_to_fit_width(max_width)
attr_body.stretch_to_fit_width(max_width)
method_body.stretch_to_fit_width(max_width)
attrGroup.next_to(headGroup, DOWN, buff=0)
methodGroup.next_to(attrGroup, DOWN, buff=0)
attr_group.next_to(head_group, DOWN, buff=0)
method_group.next_to(attr_group, DOWN, buff=0)
self.mobject = VGroup(headGroup, attrGroup, methodGroup)
self.mobject = VGroup(head_group, attr_group, method_group)
return self.mobject
def addEdge(self, edge: DiagramEdge):
def add_edge(self, edge: DiagramEdge):
self.edges.append(edge)
......@@ -6,32 +6,40 @@ from manim import *
class DiagramEdge(DiagramObject):
def __init__(self, name: str, target: DiagramObject, dashed: bool, size: int, sourceArrowType: str, targetArrowType: str):
def __init__(self,
name: str,
source: DiagramObject,
target: DiagramObject,
dashed: bool,
size: int,
source_arrow_type: Relation,
target_arrow_type: Relation):
DiagramObject.__init__(self, name)
self.source = source
self.target = target
self.dashed = dashed
self.size = size
self.sourceArrowType = sourceArrowType
self.targetArrowType = targetArrowType
self.sourceArrowType = source_arrow_type
self.targetArrowType = target_arrow_type
def draw(self, source: DiagramObject):
def draw(self):
start = source.mobject.get_top()
start = self.source.mobject.get_top()
target = self.target.mobject.get_bottom()
if source.mobject.get_top()[1] > self.target.mobject.get_bottom()[1]:
start = source.mobject.get_bottom()
if self.source.mobject.get_top()[1] > self.target.mobject.get_bottom()[1]:
start = self.source.mobject.get_bottom()
target = self.target.mobject.get_top()
startCenter = source.mobject.get_center()
targetCenter = self.target.mobject.get_center()
start_center = self.source.mobject.get_center()
target_center = self.target.mobject.get_center()
if source.y == self.target.y:
if startCenter[0] < targetCenter[0]:
start = source.mobject.get_right()
if self.source.y == self.target.y:
if start_center[0] < target_center[0]:
start = self.source.mobject.get_right()
target = self.target.mobject.get_left()
else:
start = source.mobject.get_left()
start = self.source.mobject.get_left()
target = self.target.mobject.get_right()
line = DashedLine(start, target, buff=0, stroke_width=1, tip_length=0.25) if self.dashed else Line(
......@@ -39,13 +47,13 @@ class DiagramEdge(DiagramObject):
line.color = BLACK
line.add_tip(self.getLineTip())
line.add_tip(self.get_line_tip())
self.mobject = line
return self.mobject
def getLineTip(self):
def get_line_tip(self):
if self.targetArrowType == Relation.EXTENSION:
......
......@@ -9,7 +9,7 @@ class DiagramObject(ABC):
def __init__(self, name: str):
self.name = name
self.mobject: Mobject = None
self.mobject: Mobject | None = None
self.x = 0
self.y = 0
......
......@@ -37,105 +37,109 @@ class PUMLParser(object):
relation : IDENTIFIER REL LINE IDENTIFIER
"""
leftClassName = str(p[1])
left_class_name = str(p[1])
relation = p[2]
lineData = p[3]
rightClassName = str(p[4])
line_data = p[3]
right_class_name = str(p[4])
leftClass = DiagramClass(leftClassName, 'class')
rightClass = DiagramClass(rightClassName, 'class')
left_class = DiagramClass(left_class_name, 'class')
right_class = DiagramClass(right_class_name, 'class')
line = DiagramEdge(
leftClassName + "-" + relation + "-" + rightClassName,
leftClass,
lineData[1],
lineData[0],
left_class_name + "-" + relation + "-" + right_class_name,
right_class,
left_class,
line_data[1],
line_data[0],
Relation["NONE"],
Relation[relation],
)
rightClass.addEdge(line)
right_class.add_edge(line)
self.diagram.addObject(leftClass)
self.diagram.addObject(rightClass)
self.diagram.add_object(left_class)
self.diagram.add_object(right_class)
def p_right_relation(self, p):
"""
relation : IDENTIFIER LINE REL IDENTIFIER
"""
leftClassName = str(p[1])
lineData = p[2]
left_class_name = str(p[1])
line_data = p[2]
relation = p[3]
rightClassName = str(p[4])
right_class_name = str(p[4])
leftClass = DiagramClass(leftClassName, 'class')
rightClass = DiagramClass(rightClassName, 'class')
left_class = DiagramClass(left_class_name, 'class')
right_class = DiagramClass(right_class_name, 'class')
line = DiagramEdge(
leftClassName + "-" + relation + "-" + rightClassName,
rightClass,
lineData[1],
lineData[0],
left_class_name + "-" + relation + "-" + right_class_name,
left_class,
right_class,
line_data[1],
line_data[0],
Relation["NONE"],
Relation[relation],
)
leftClass.addEdge(line)
left_class.add_edge(line)
self.diagram.addObject(leftClass)
self.diagram.addObject(rightClass)
self.diagram.add_object(left_class)
self.diagram.add_object(right_class)
def p_simple_relation(self, p):
"""
relation : IDENTIFIER LINE IDENTIFIER
"""
leftClassName = str(p[1])
lineData = p[2]
rightClassName = str(p[3])
left_class_name = str(p[1])
line_data = p[2]
right_class_name = str(p[3])
leftClass = DiagramClass(leftClassName, 'class')
rightClass = DiagramClass(rightClassName, 'class')
left_class = DiagramClass(left_class_name, 'class')
right_class = DiagramClass(right_class_name, 'class')
line = DiagramEdge(
leftClassName + "-" + rightClassName,
leftClass,
lineData[1],
lineData[0],
left_class_name + "-" + right_class_name,
right_class,
left_class,
line_data[1],
line_data[0],
Relation["NONE"],
Relation["NONE"],
)
rightClass.addEdge(leftClass)
right_class.add_edge(line)
self.diagram.addObject(leftClass)
self.diagram.addObject(rightClass)
self.diagram.add_object(left_class)
self.diagram.add_object(right_class)
def p_bi_relation(self, p):
"""
relation : IDENTIFIER REL LINE REL IDENTIFIER
"""
leftClassName = str(p[1])
leftRelation = p[2]
lineData = p[3]
rightRelation = p[4]
rightClassName = str(p[5])
left_class_name = str(p[1])
left_relation = p[2]
line_data = p[3]
right_relation = p[4]
right_class_name = str(p[5])
leftClass = DiagramClass(leftClassName, 'class')
rightClass = DiagramClass(rightClassName, 'class')
left_class = DiagramClass(left_class_name, 'class')
right_class = DiagramClass(right_class_name, 'class')
line = DiagramEdge(
leftClassName + "-" + leftRelation + "-" + rightRelation + "-" + rightClassName,
leftClass,
lineData[1],
lineData[0],
Relation[leftRelation],
Relation[rightRelation],
left_class_name + "-" + left_relation + "-" + right_relation + "-" + right_class_name,
right_class,
left_class,
line_data[1],
line_data[0],
Relation[left_relation],
Relation[right_relation],
)
rightClass.addEdge(line)
right_class.add_edge(line)
self.diagram.addObject(leftClass)
self.diagram.addObject(rightClass)
self.diagram.add_object(left_class)
self.diagram.add_object(right_class)
def p_class(self, p):
"""
......@@ -161,15 +165,15 @@ class PUMLParser(object):
| ABS_CLASS CLASS STRING
"""
classType = str(p[1]).lower()
class_type = str(p[1]).lower()
name = str(p[2])
if classType == "abstract":
classType = "abstract_class"
if class_type == "abstract":
class_type = "abstract_class"
name = str(p[3])
classObj = DiagramClass(name, classType)
class_obj = DiagramClass(name, class_type)
self.diagram.addObject(classObj)
self.diagram.add_object(class_obj)
def p_class_attr(self, p):
"""
......@@ -177,24 +181,25 @@ class PUMLParser(object):
class_attr : STRING AFTERCOLON
"""
o = self.diagram.objects[p[1]]
attrStr = str(p[2])
isMethod = '(' in attrStr
o: DiagramClass = self.diagram.objects[p[1]]
attr_str = str(p[2])
is_method = '(' in attr_str
if attrStr[0] in ['-', '~', '#', '+']:
if attr_str[0] in ['-', '~', '#', '+']:
attribute = ClassAttribute(
isMethod, AttributeModifier.fromString(attrStr[0]), attrStr[1:])
is_method, AttributeModifier.from_string(attr_str[0]), attr_str[1:])
else:
attrStr = attrStr.strip()
attr_str = attr_str.strip()
attribute = ClassAttribute(
isMethod, AttributeModifier.NONE, attrStr)
is_method, AttributeModifier.NONE, attr_str)
if isMethod:
if is_method:
o.methods.append(attribute)
else:
o.attributes.append(attribute)
def p_error(self, p):
@staticmethod
def p_error(p):
print("Parser syntax error:")
print("\t", p)
......@@ -207,7 +212,7 @@ class PUMLParser(object):
def parse(self, text) -> Diagram:
return self.parser.parse(text)
def parseFile(self, path):
def parse_file(self, path):
with open(path, 'r') as file:
text = file.read()
......
import ply.lex as lex
# noinspection PyPep8Naming
class PUMLexer(object):
keywords = {
"@startuml": "START",
"@enduml": "END",
......@@ -19,86 +19,98 @@ class PUMLexer(object):
}
tokens = [
"EXTENSION",
"ASSOCIATION",
"COMPOSITION",
"AGGREGATION",
"REL",
"STRING",
"IDENTIFIER",
"LINE",
"AFTERCOLON",
] + list(keywords.values())
def t_LINE(self, t):
r'[-\.]+'
"EXTENSION",
"ASSOCIATION",
"COMPOSITION",
"AGGREGATION",
"REL",
"STRING",
"IDENTIFIER",
"LINE",
"AFTERCOLON",
] + list(keywords.values())
@staticmethod
def t_LINE(t):
r"""[-\.]+"""
t.value = (int(t.value.count(".") + t.value.count("-")), "." in t.value)
return t
def t_STRING(self, t):
r'"(.*?)"'
@staticmethod
def t_STRING(t):
r""""(.*?)\""""
t.value = t.value.replace("\"", "")
return t
def t_EXTENSION(self, t):
r'<\||\|>|\^'
@staticmethod
def t_EXTENSION(t):
r"""<\||\|>|\^"""
t.type = 'REL'
t.value = 'EXTENSION'
return t
def t_ASSOCIATION(self, t):
r'<|>'
@staticmethod
def t_ASSOCIATION(t):
r"""<|>"""
t.type = 'REL'
t.value = 'ASSOCIATION'
return t
def t_AGGREGATION(self, t):
r'o'
@staticmethod
def t_AGGREGATION(t):
r"""o"""
t.type = 'REL'
t.value = 'AGGREGATION'
return t
def t_COMPOSITION(self, t):
r'\*'
@staticmethod
def t_COMPOSITION(t):
r"""\*"""
t.type = 'REL'
t.value = 'COMPOSITION'
return t
def t_HASH(self, t):
r'\#'
@staticmethod
def t_HASH(t):
r"""\#"""
t.type = 'REL'
t.value = 'HASH'
return t
def t_CROSS(self, t):
r'x'
@staticmethod
def t_CROSS(t):
r"""x"""
t.type = 'REL'
t.value = 'CROSS'
return t
def t_CROW_FOOT(self, t):
r'\{|\}'
@staticmethod
def t_CROW_FOOT(t):
r"""\{|\}"""
t.type = 'REL'
t.value = 'CROW_FOOT'
return t
def t_NEST_CLASSIFIER(self, t):
r'\+'
@staticmethod
def t_NEST_CLASSIFIER(t):
r"""\+"""
t.type = 'REL'
t.value = 'NEST_CLASSIFIER'
return t
def t_IDENTIFIER(self, t):
r'@*\w+[()]*'
r"""@*\w+[()]*"""
t.type = self.keywords.get(t.value.lower(), 'IDENTIFIER')
return t
def t_AFTERCOLON(self, t):
r':.+'
@staticmethod
def t_AFTERCOLON(t):
r""":.+"""
t.value = t.value[1:].strip()
return t
def t_error(self, t):
@staticmethod
def t_error(t):
print("Illegal character: '{}'".format(t.value[0]))
t.lexer.skip(1)
......
......@@ -3,6 +3,7 @@ import networkx as nx
from .. import Diagram
class DiagramLayout(ABC):
def __init__(self, diagram: Diagram) -> None:
......@@ -12,16 +13,16 @@ class DiagramLayout(ABC):
def apply(self) -> None:
pass
def getGraph(self) -> nx.DiGraph:
G = nx.DiGraph()
def get_graph(self) -> nx.DiGraph:
g = nx.DiGraph()
for name, obj in self.diagram.objects.items():
G.add_node(name)
g.add_node(name)
for e in obj.edges:
G.add_edge(name, e.target.name)
return G
g.add_edge(name, e.target.name)
return g
def scale(self, x: float, y: float) -> None:
for name, obj in self.diagram.objects.items():
......
from networkx.drawing.nx_agraph import graphviz_layout
from .DiagramLayout import DiagramLayout
from .. import Diagram
class DotLayout(DiagramLayout):
def __init__(self, diagram: Diagram) -> None:
self.diagram = diagram
class DotLayout(DiagramLayout):
def apply(self) -> None:
layout = graphviz_layout(self.getGraph(), prog='dot')
layout = graphviz_layout(self.get_graph(), prog='dot')
for key, pos in layout.items():
self.diagram.objects[key].x = pos[0]
......
import networkx as nx
from .DiagramLayout import DiagramLayout
from .. import Diagram
class KamadaKawai(DiagramLayout):
def __init__(self, diagram: Diagram) -> None:
self.diagram = diagram
class KamadaKawai(DiagramLayout):
def apply(self) -> None:
layout = nx.kamada_kawai_layout(self.getGraph())
layout = nx.kamada_kawai_layout(self.get_graph())
for key, pos in layout.items():
self.diagram.objects[key].x = pos[0]
......
import networkx as nx
from .DiagramLayout import DiagramLayout
from .. import Diagram
class SpringLayout(DiagramLayout):
def __init__(self, diagram: Diagram) -> None:
self.diagram = diagram
class SpringLayout(DiagramLayout):
def apply(self) -> None:
layout = nx.spring_layout(self.getGraph())
layout = nx.spring_layout(self.get_graph())
for key, pos in layout.items():
self.diagram.objects[key].x = pos[0]
......
......@@ -6,36 +6,36 @@ class PyPlantUMLTest(unittest.TestCase):
def test_input_elements(self):
p = PUMLParser()
d = p.parseFile("../examples/01_elements.puml")
d = p.parse_file("../examples/01_elements.puml")
self.assertIsNotNone(d)
def test_class_tree(self):
p = PUMLParser()
d = p.parseFile("../examples/02_class_tree.puml")
d = p.parse_file("../examples/02_class_tree.puml")
self.assertIsNotNone(d)
def test_relations(self):
p = PUMLParser()
d = p.parseFile("../examples/03_relations.puml")
d = p.parse_file("../examples/03_relations.puml")
self.assertIsNotNone(d)
def test_relations_extra(self):
p = PUMLParser()
d = p.parseFile("../examples/04_relations_extra.puml")
d = p.parse_file("../examples/04_relations_extra.puml")
self.assertIsNotNone(d)
def test_class_attributes(self):
p = PUMLParser()
d = p.parseFile("../examples/05_class_attributes.puml")
d = p.parse_file("../examples/05_class_attributes.puml")
self.assertIsNotNone(d)
def test_class_tree_with_attributes(self):
p = PUMLParser()
d = p.parseFile("../examples/06_class_tree_with_attributes.puml")
d = p.parse_file("../examples/06_class_tree_with_attributes.puml")
self.assertIsNotNone(d)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment