From 5b52791b73c740bf94fc5e45c3ebbc77a440351c Mon Sep 17 00:00:00 2001
From: Mesharo <Hecko97@seznam.cz>
Date: Wed, 18 Dec 2024 09:00:33 +0100
Subject: [PATCH] linkovani, aktualni stav projektu

---
 main.py  | 137 ++++++++++++++++++++++++++++++++++---------------------
 tests.py | 121 +++++++++++++++++++++++++++---------------------
 2 files changed, 154 insertions(+), 104 deletions(-)

diff --git a/main.py b/main.py
index bf25ab7..2d3e0e5 100644
--- a/main.py
+++ b/main.py
@@ -94,59 +94,101 @@ def erase_html(code_section: list) -> str:
 
     return replace_unrecognized_letters(result)
 
+def can_be_parsed(code: str) -> bool:
+    try:
+        expression_tree = sqlglot.parse(code, dialect='postgres')
+
+        if expression_tree == [None]:
+            return False
+
+        return True
+    except:
+        return False
+
+def can_be_scoped(ast) -> bool:
+    try:
+        root = sqlglot.optimizer.scope.build_scope(ast)
+        return True
+    except:
+        return False
+
+def can_be_qualified(ast) -> bool:
+    try:
+        sqlglot.optimizer.qualify.qualify(ast)
+
+        return True
+    except:
+        return False
+
+def solve_with_scopes(ast) -> list:
+    result = []
+    columns = []
+    tables = []
+    aliases = []
+
+    root = sqlglot.optimizer.scope.build_scope(ast)   
+    for scope in root.traverse():
+        pass
+
+def solve_without_scopes(ast) -> list:
+    result = []
+    columns = []
+    tables = []
+    aliases = []
+
+    tmp = list(ast.find_all(sqlglot.exp.Column))
+    for column in tmp:
+        columns.append(str(column))            
+
+    tmp = list(ast.find_all(sqlglot.exp.Table))
+    for table in tmp:
+        tables.append(str(table))
+
+    tmp = list(ast.find_all(sqlglot.exp.Alias))
+    for alias in tmp:
+        aliases.append(str(alias))
+
+    print(f'Cols: {columns}\nTabs: {tables}\nAliases: {aliases}')
+
+    # TODO: traverse the AST again using walk() and link the collected columns, tables and aliases together
+
+
 def analyze(id: str, codes: list) -> tuple:
-    parsed = 0
-    not_parsed = 0
-    is_none = 0
-    not_empty_columns_tables_aliases = 0
+    result = []
+    columns_tables = []
+
     for code_list in codes:
         all_codes_string = erase_html(code_list)
         for code in all_codes_string.split(';'):
-            #print('----------------------')
-            #print(code)
+            if not can_be_parsed(code):
+                continue
 
-            correct = False
             try:
                 expression_tree = sqlglot.parse(code, dialect='postgres')
-                if expression_tree == [None]:
-                    is_none += 1
-                    correct = True
-                    continue
                 
                 for ast in expression_tree:
-                    columns = ast.find_all(sqlglot.exp.Column)
-                    tables = ast.find_all(sqlglot.exp.Table)
-                    aliases = ast.find_all(sqlglot.exp.Alias)
+                    if (can_be_qualified(ast)):
+                        sqlglot.optimizer.qualify.qualify(ast)
+
+                    if (can_be_scoped(ast)):
+                        columns_tables = solve_with_scopes(ast)
                     
-                    if next(columns, None) is not None:
-                        not_empty_columns_tables_aliases += 1
-                        continue
-                    if next(tables, None) is not None:
-                        not_empty_columns_tables_aliases += 1
-                        continue
-                    if next(aliases, None) is not None:
-                        not_empty_columns_tables_aliases += 1
-                        continue
-
-                correct = True
+                    columns_tables = solve_without_scopes(ast)
+
+                    if columns_tables:
+                        result.append(columns_tables)
+                        columns_tables.clear()
+
             except sqlglot.errors.ParseError as pe:
-                correct = False
-                #print(f'----\nParseError: {pe}\n-----')
+                continue
             except sqlglot.errors.TokenError as te:
-                correct = False
-                #print(f'----\nTokenError: {te}\n-----')
+                continue
             except sqlglot.errors.OptimizeError as oe:
-                correct = False
-                #print(f'----\nOptimizeError: {oe}\n-----')
+                continue
             except:
-                correct = False
-
-            if correct:
-                parsed += 1
-            else:
-                not_parsed += 1   
+                continue
 
-    return (parsed, not_parsed, is_none, not_empty_columns_tables_aliases)
+    return (id, result)
 
 def run(input_filepath_all_answers: str, input_filepath_postgresql_questions: str, input_filepath_linked: str, input_filepath_codes: str) -> None:
     """Main function.
@@ -164,25 +206,16 @@ def run(input_filepath_all_answers: str, input_filepath_postgresql_questions: st
 
     if os.path.isfile(input_filepath_codes):
         codes = load_code_sections(input_filepath_codes)
-        parsed = 0
-        not_parsed = 0
-        is_none = 0
-        not_empty_columns_tables_aliases = 0
 
+        count = 0
         for key, values in codes.items():
             if not values:
                 continue
-
-            tmp = analyze(key, values)
-            parsed += tmp[0]
-            not_parsed += tmp[1]
-            is_none += tmp[2]
-            not_empty_columns_tables_aliases += tmp[3]
             
-        print(f'Parsed: {parsed}, not parsed: {not_parsed}, None: {is_none} (included in Parsed), Found col/table/alias: {not_empty_columns_tables_aliases}')
-        print('DONE!')
-        # new dataset: Parsed: 424400, not parsed: 201498, None: 80004 (included in Parsed)
-        # Parsed: 344396, not parsed: 201498, None: 80004 (not included in Parsed), Found col/table/alias: 280418
+            while count < 100:
+                tmp = analyze(key, values)
+                print(f'ID: {tmp[0]}, columns and tables: {tmp[1]}')
+
         return
 
     if os.path.isfile(input_filepath_linked):
diff --git a/tests.py b/tests.py
index c88b2f4..57cc7df 100644
--- a/tests.py
+++ b/tests.py
@@ -78,6 +78,7 @@ def testing2():
 
     print(mydict)
 
+import sqlglot.expressions
 import sqlglot.optimizer.scope
 import sqlparse
 def erasing_backslashes():
@@ -170,66 +171,82 @@ def tmp():
     except sqlglot.errors.OptimizeError:
         expression_tree = undo
 
+def solve_schema(schema: str) -> list:
+    pass
+
+def solve_subquery(subquery: str) -> list:
+    # Recursion over the AST (not implemented yet)
+    pass
+
 def whatever():
-    expr = "SELECT xd.hello FROM y AS xd"
+    #expr = "SELECT hello FROM a INNER JOIN b;"
+    expr = 'SELECT xd, (SELECT a FROM b) FROM c;'
+    #expr = "INSERT INTO first(xdd, xdd2) VALUES((SELECT x FROM third), 1);"
+    #expr = """WITH x AS (SELECT a FROM y) SELECT a FROM x"""
     try:
         expression_tree = sqlglot.parse(expr)
         print(repr(expression_tree))
                 
-        for tmp in expression_tree:
-            columns = tmp.find_all(exp.Column)
-            tables = tmp.find_all(exp.Table)
-            aliases = tmp.find_all(exp.Alias)
-
-            print('Columns!')
-            for column in columns:
-                print(column.name)
-
-            print('Tables!')
-            for table in tables:
-                print(table.name)
-
-            print('Aliases!')
-            for alias in aliases:
-                print(alias.name)
-        
-        """
-        for node in expression_tree.args['expressions']:
-            if isinstance(node, exp.Column):
-                if (node.args['this']):
-                    print(f'Column name: {node.args["this"]}')
-                if (node.args['table']):
-                    print(f'from table: {node.args["table"]}')
+        for ast in expression_tree:
+            """
+            tmp = list(ast.find_all(sqlglot.exp.Schema))
+            for column in tmp:
+                print(column)  
+
+            tmp = list(ast.find_all(sqlglot.exp.Column))
+            for column in tmp:
+                columns.append(str(column))            
+
+            tmp = list(ast.find_all(sqlglot.exp.Table))
+            for table in tmp:
+                tables.append(str(table))
+
+            tmp = list(ast.find_all(sqlglot.exp.Alias))
+            for alias in tmp:
+                aliases.append(str(alias))"""
+            
+            walk = []
+            sub_result = []
+            
+            tmp = list(ast.find_all(sqlglot.exp.Schema))
+            for schema in tmp:
+                pass
+                #sub_result.append(f'Schema:{solve_schema(schema)}')
+                #ast.args['schema'].replace(sqlglot.exp.Alias(this=sqlglot.exp.Identifier("SOLVED")))
+            #print('first (xdd, xdd2)' == 'first(xdd, xdd2)')
+
+            tmp = list(ast.find_all(sqlglot.exp.Subquery))
+            for subquery in tmp:
+                pass
+                #sub_result.append(f'Subquery:{solve_subquery(subquery)}')
+                #ast.args['subquery'].replace(sqlglot.exp.Alias(this=sqlglot.exp.Identifier("SOLVED")))
+
+            for node in ast.walk(bfs=False):
+                if isinstance(node, sqlglot.expressions.Table):
+                    walk.append(f'Table:{str(node)}')
+                    continue
+                if isinstance(node, sqlglot.expressions.Column):
+                    walk.append(f'Column:{str(node)}')
+                    continue
+                if isinstance(node, sqlglot.expressions.Join):
+                    walk.append(f'Join')
+                    continue
+                #if isinstance(node, sqlglot.expressions.Alias):
+                    #aliases.append(str(node))
+
+        print(expr)
+        print(f'Walk: {walk}')
+        #print(f'Cols: {columns}\nTabs: {tables}\nAliases: {aliases}')
             
-            if isinstance(node, exp.Alias):
-                if (node.args['this']):
-                    print(f'Column name: {node.args["this"]}')
-
-            if isinstance(node, exp.Table):
-                if (node.args['this']):
-                    print(f'Column name: {node.args["this"]}')
-                if (node.args['alias']):
-                    print(f'from table: {node.args["alias"]}')
-        """
 
     except sqlglot.errors.ParseError as pe:
-        print(f'ParseError: {pe}')
-
-def test_qualify():
-    statement = "INSERT INTO first VALUES((SELECT hello FROM second), world);"
-
-    try:
-        ASTs = sqlglot.parse(statement)
-        print(repr(ASTs))
-        for AST in ASTs:
-            root = sqlglot.optimizer.scope.build_scope(AST)
-            print(root)
-            for scope in root.traverse():
-                print(scope)
+        print(pe)
+    except sqlglot.errors.TokenError as te:
+        print(te)
     except sqlglot.errors.OptimizeError as oe:
-        print(f'----\nOptimizeError: {oe}\n-----')
-
+        print(oe)
+    except:
+        print('GG')
 
 if __name__ == '__main__':
-    #whatever()
-    test_qualify()
\ No newline at end of file
+    whatever()
\ No newline at end of file
-- 
GitLab