From 5b52791b73c740bf94fc5e45c3ebbc77a440351c Mon Sep 17 00:00:00 2001 From: Mesharo <Hecko97@seznam.cz> Date: Wed, 18 Dec 2024 09:00:33 +0100 Subject: [PATCH] linkovani, aktualni stav projektu --- main.py | 137 ++++++++++++++++++++++++++++++++++--------------------- tests.py | 121 +++++++++++++++++++++++++++--------------------- 2 files changed, 154 insertions(+), 104 deletions(-) diff --git a/main.py b/main.py index bf25ab7..2d3e0e5 100644 --- a/main.py +++ b/main.py @@ -94,59 +94,101 @@ def erase_html(code_section: list) -> str: return replace_unrecognized_letters(result) +def can_be_parsed(code: str) -> bool: + try: + expression_tree = sqlglot.parse(code, dialect='postgres') + + if expression_tree == [None]: + return False + + return True + except: + return False + +def can_be_scoped(ast) -> bool: + try: + root = sqlglot.optimizer.scope.build_scope(ast) + return True + except: + return False + +def can_be_qualified(ast) -> bool: + try: + sqlglot.optimizer.qualify.qualify(ast) + + return True + except: + return False + +def solve_with_scopes(ast) -> list: + result = [] + columns = [] + tables = [] + aliases = [] + + root = sqlglot.optimizer.scope.build_scope(ast) + for scope in root.traverse(): + pass + +def solve_without_scopes(ast) -> list: + result = [] + columns = [] + tables = [] + aliases = [] + + tmp = list(ast.find_all(sqlglot.exp.Column)) + for column in tmp: + columns.append(str(column)) + + tmp = list(ast.find_all(sqlglot.exp.Table)) + for table in tmp: + tables.append(str(table)) + + tmp = list(ast.find_all(sqlglot.exp.Alias)) + for alias in tmp: + aliases.append(str(alias)) + + print(f'Cols: {columns}\nTabs: {tables}\nAliases: {aliases}') + + #TODO traversing again using walk and link together + + def analyze(id: str, codes: list) -> tuple: - parsed = 0 - not_parsed = 0 - is_none = 0 - not_empty_columns_tables_aliases = 0 + result = [] + columns_tables = [] + for code_list in codes: all_codes_string = erase_html(code_list) for code in all_codes_string.split(';'): - #print('----------------------') - #print(code) + if not can_be_parsed(code): + continue - correct = False try: expression_tree = sqlglot.parse(code, dialect='postgres') - if expression_tree == [None]: - is_none += 1 - correct = True - continue for ast in expression_tree: - columns = ast.find_all(sqlglot.exp.Column) - tables = ast.find_all(sqlglot.exp.Table) - aliases = ast.find_all(sqlglot.exp.Alias) + if (can_be_qualified(ast)): + sqlglot.optimizer.qualify.qualify(ast) + + if (can_be_scoped(ast)): + columns_tables = solve_with_scopes(ast) - if next(columns, None) is not None: - not_empty_columns_tables_aliases += 1 - continue - if next(tables, None) is not None: - not_empty_columns_tables_aliases += 1 - continue - if next(aliases, None) is not None: - not_empty_columns_tables_aliases += 1 - continue - - correct = True + columns_tables = solve_without_scopes(ast) + + if columns_tables: + result.append(columns_tables) + columns_tables.clear() + except sqlglot.errors.ParseError as pe: - correct = False - #print(f'----\nParseError: {pe}\n-----') + continue except sqlglot.errors.TokenError as te: - correct = False - #print(f'----\nTokenError: {te}\n-----') + continue except sqlglot.errors.OptimizeError as oe: - correct = False - #print(f'----\nOptimizeError: {oe}\n-----') + continue except: - correct = False - - if correct: - parsed += 1 - else: - not_parsed += 1 + continue - return (parsed, not_parsed, is_none, not_empty_columns_tables_aliases) + return (id, result) def run(input_filepath_all_answers: str, input_filepath_postgresql_questions: str, input_filepath_linked: str, input_filepath_codes: str) -> None: """Main function. @@ -164,25 +206,16 @@ def run(input_filepath_all_answers: str, input_filepath_postgresql_questions: st if os.path.isfile(input_filepath_codes): codes = load_code_sections(input_filepath_codes) - parsed = 0 - not_parsed = 0 - is_none = 0 - not_empty_columns_tables_aliases = 0 + count = 0 for key, values in codes.items(): if not values: continue - - tmp = analyze(key, values) - parsed += tmp[0] - not_parsed += tmp[1] - is_none += tmp[2] - not_empty_columns_tables_aliases += tmp[3] - print(f'Parsed: {parsed}, not parsed: {not_parsed}, None: {is_none} (included in Parsed), Found col/table/alias: {not_empty_columns_tables_aliases}') - print('DONE!') - # new dataset: Parsed: 424400, not parsed: 201498, None: 80004 (included in Parsed) - # Parsed: 344396, not parsed: 201498, None: 80004 (not included in Parsed), Found col/table/alias: 280418 + while count < 100: + tmp = analyze(key, values) + print(f'ID: {tmp[0]}, columns and tables: {tmp[1]}') + return if os.path.isfile(input_filepath_linked): diff --git a/tests.py b/tests.py index c88b2f4..57cc7df 100644 --- a/tests.py +++ b/tests.py @@ -78,6 +78,7 @@ def testing2(): print(mydict) +import sqlglot.expressions import sqlglot.optimizer.scope import sqlparse def erasing_backslashes(): @@ -170,66 +171,82 @@ def tmp(): except sqlglot.errors.OptimizeError: expression_tree = undo +def solve_schema(schema: str) -> list: + pass + +def solve_subquery(subquery: str) -> list: + # Rekurze - ast + pass + def whatever(): - expr = "SELECT xd.hello FROM y AS xd" + #expr = "SELECT hello FROM a INNER JOIN b;" + expr = 'SELECT xd, (SELECT a FROM b) FROM c;' + #expr = "INSERT INTO first(xdd, xdd2) VALUES((SELECT x FROM third), 1);" + #expr = """WITH x AS (SELECT a FROM y) SELECT a FROM x""" try: expression_tree = sqlglot.parse(expr) print(repr(expression_tree)) - for tmp in expression_tree: - columns = tmp.find_all(exp.Column) - tables = tmp.find_all(exp.Table) - aliases = tmp.find_all(exp.Alias) - - print('Columns!') - for column in columns: - print(column.name) - - print('Tables!') - for table in tables: - print(table.name) - - print('Aliases!') - for alias in aliases: - print(alias.name) - - """ - for node in expression_tree.args['expressions']: - if isinstance(node, exp.Column): - if (node.args['this']): - print(f'Column name: {node.args["this"]}') - if (node.args['table']): - print(f'from table: {node.args["table"]}') + for ast in expression_tree: + """ + tmp = list(ast.find_all(sqlglot.exp.Schema)) + for column in tmp: + print(column) + + tmp = list(ast.find_all(sqlglot.exp.Column)) + for column in tmp: + columns.append(str(column)) + + tmp = list(ast.find_all(sqlglot.exp.Table)) + for table in tmp: + tables.append(str(table)) + + tmp = list(ast.find_all(sqlglot.exp.Alias)) + for alias in tmp: + aliases.append(str(alias))""" + + walk = [] + sub_result = [] + + tmp = list(ast.find_all(sqlglot.exp.Schema)) + for schema in tmp: + pass + #sub_result.append(f'Schema:{solve_schema(schema)}') + #ast.args['schema'].replace(sqlglot.exp.Alias(this=sqlglot.exp.Identifier("SOLVED"))) + #print('first (xdd, xdd2)' == 'first(xdd, xdd2)') + + tmp = list(ast.find_all(sqlglot.exp.Subquery)) + for subquery in tmp: + pass + #sub_result.append(f'Subquery:{solve_subquery(subquery)}') + #ast.args['subquery'].replace(sqlglot.exp.Alias(this=sqlglot.exp.Identifier("SOLVED"))) + + for node in ast.walk(bfs=False): + if isinstance(node, sqlglot.expressions.Table): + walk.append(f'Table:{str(node)}') + continue + if isinstance(node, sqlglot.expressions.Column): + walk.append(f'Column:{str(node)}') + continue + if isinstance(node, sqlglot.expressions.Join): + walk.append(f'Join') + continue + #if isinstance(node, sqlglot.expressions.Alias): + #aliases.append(str(node)) + + print(expr) + print(f'Walk: {walk}') + #print(f'Cols: {columns}\nTabs: {tables}\nAliases: {aliases}') - if isinstance(node, exp.Alias): - if (node.args['this']): - print(f'Column name: {node.args["this"]}') - - if isinstance(node, exp.Table): - if (node.args['this']): - print(f'Column name: {node.args["this"]}') - if (node.args['alias']): - print(f'from table: {node.args["alias"]}') - """ except sqlglot.errors.ParseError as pe: - print(f'ParseError: {pe}') - -def test_qualify(): - statement = "INSERT INTO first VALUES((SELECT hello FROM second), world);" - - try: - ASTs = sqlglot.parse(statement) - print(repr(ASTs)) - for AST in ASTs: - root = sqlglot.optimizer.scope.build_scope(AST) - print(root) - for scope in root.traverse(): - print(scope) + print(pe) + except sqlglot.errors.TokenError as te: + print(te) except sqlglot.errors.OptimizeError as oe: - print(f'----\nOptimizeError: {oe}\n-----') - + print(oe) + except: + print('GG') if __name__ == '__main__': - #whatever() - test_qualify() \ No newline at end of file + whatever() \ No newline at end of file -- GitLab