# plots.py - plotting helpers
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Any, Dict, Tuple
import re

from .processing import *
from .utils import from_class_to_text

# %%
def PlotFilter (filter: np.ndarray, sampling_rate: int = 1000):
    """Draw a filter's impulse response and its spectrum in one figure.

    The upper subplot is a stem plot of the filter samples. The lower
    subplot shows the output of ``FFT(filter, sampling_rate, 9 * len(filter))``
    on a logarithmic amplitude axis.

    Args:
        filter (np.ndarray): Filter samples (impulse response).
        sampling_rate (int, optional): Sampling rate forwarded to ``FFT``.
            Defaults to 1000.

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure()

    # Upper panel: the raw coefficients.
    impulse_ax = plt.subplot(2, 1, 1)
    impulse_ax.set_title("Impulse response")
    impulse_ax.set_xlabel("Sample")
    impulse_ax.set_ylabel("Amplitude")
    impulse_ax.stem(filter)

    # Lower panel: spectrum returned by FFT. The third argument is presumably
    # the transform length (zero padding to 9x) - confirm against FFT itself.
    response_ax = plt.subplot(2, 1, 2)
    n_points = 9 * len(filter)
    freqs, fft = FFT(filter, sampling_rate, n_points)
    response_ax.plot(freqs, fft)
    response_ax.set_title("Frequency Response")
    response_ax.set_xlabel("Frequency (Hz)")
    response_ax.set_ylabel("Amplitude")
    response_ax.set_yscale("log")
    response_ax.grid()

    plt.tight_layout()

    return fig

# %%
def correlation_plot(lag: np.ndarray, correlation: np.ndarray, name: str = None, save: bool = False):
    """Plot a correlation curve against its lag axis.

    Args:
        lag (np.ndarray): Values for the x axis ("Lag").
        correlation (np.ndarray): Values for the y axis ("Correlation").
        name (str, optional): Figure title. Also used (sanitised) as the
            file name when ``save`` is True. Defaults to None (no title).
        save (bool, optional): When True and ``name`` is a string, write the
            figure to ``figures/correlation/<name>.png``. Defaults to False.

    Returns:
        The newly created matplotlib figure.
    """
    f = plt.figure()

    plt.plot(lag, correlation)

    plt.grid()
    plt.xlabel("Lag")
    plt.ylabel("Correlation")

    has_name = isinstance(name, str)
    if has_name:
        plt.title(name)

    plt.tight_layout()

    # Without a name there is nothing to build a file name from, so saving
    # is silently skipped in that case.
    if has_name and save:
        # Anything that is not a word character, parenthesis or space
        # becomes "_" so the title is usable as a file name.
        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
        # Build the folder from components: a literal "figures\\correlation\\"
        # is only split into directories on Windows.
        folder = Path("figures") / "correlation"
        folder.mkdir(parents=True, exist_ok=True)
        plt.savefig(folder / (escaped + ".png"))

    return f

# %%
def FFT_plot(freqs: np.ndarray, fft: np.ndarray, log: bool = False, name: str = None, save: bool = False):
    """Plot a spectrum against frequency.

    Args:
        freqs (np.ndarray): Values for the x axis ("Frequency (Hz)").
        fft (np.ndarray): Values for the y axis.
        log (bool, optional): When True the y axis uses a logarithmic scale
            and is labelled "Amplitude (dB)". Defaults to False.
        name (str, optional): Figure title. Also used (sanitised) as the
            file name when ``save`` is True. Defaults to None (no title).
        save (bool, optional): When True and ``name`` is a string, write the
            figure to ``figures/FFT/<name>.png``. Defaults to False.

    Returns:
        The newly created matplotlib figure.
    """
    f = plt.figure()

    plt.plot(freqs, fft)

    plt.grid()
    plt.xlabel("Frequency (Hz)")
    # NOTE(review): the values are not converted to dB here; with log=True
    # they are only drawn on a logarithmic axis. Label kept as-is.
    plt.ylabel("Amplitude" if not log else "Amplitude (dB)")
    if log:
        plt.yscale("log")

    has_name = isinstance(name, str)
    if has_name:
        plt.title(name)

    plt.tight_layout()

    # Saving needs a name to derive the file name from.
    if has_name and save:
        # Anything that is not a word character, parenthesis or space
        # becomes "_" so the title is usable as a file name.
        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
        # Build the folder from components: a literal "figures\\FFT\\"
        # is only split into directories on Windows.
        folder = Path("figures") / "FFT"
        folder.mkdir(parents=True, exist_ok=True)
        plt.savefig(folder / (escaped + ".png"))

    return f

# %%
def Time_plot(time: np.ndarray, signal: np.ndarray, name: str = None, save: bool = False):
    """Plot a signal against time.

    Args:
        time (np.ndarray): Values for the x axis ("Time (s)").
        signal (np.ndarray): Values for the y axis.
        name (str, optional): Used as the y-axis label and as the figure
            title; also used (sanitised) as the file name when ``save`` is
            True. Defaults to None, in which case the y axis is labelled
            "Amplitude" and no title is set.
        save (bool, optional): When True and ``name`` is a string, write the
            figure to ``figures/Time/<name>.png``. Defaults to False.

    Returns:
        The newly created matplotlib figure.
    """
    f = plt.figure()

    plt.plot(time, signal)

    plt.grid()
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude" if name is None else name)

    has_name = isinstance(name, str)
    if has_name:
        plt.title(name)

    plt.tight_layout()

    # Saving needs a name to derive the file name from.
    if has_name and save:
        # Anything that is not a word character, parenthesis or space
        # becomes "_" so the title is usable as a file name.
        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
        # Build the folder from components: a literal "figures\\Time\\"
        # is only split into directories on Windows.
        folder = Path("figures") / "Time"
        folder.mkdir(parents=True, exist_ok=True)
        plt.savefig(folder / (escaped + ".png"))

    return f

#%%
def FFT_by_plot(cls: List[Any], freqs: np.ndarray, ffts: List[np.ndarray], log: bool = False, name: str = None, save: bool = False):
    """Draw several spectra as one image, one row per entry of ``cls``.

    Args:
        cls (List[Any]): Row labels; ``len(cls)`` sets the vertical extent
            and the y ticks of the image.
        freqs (np.ndarray): Frequency axis; its first and last values set
            the horizontal extent of the image.
        ffts (List[np.ndarray]): Spectra to stack into the image
            (presumably one per entry of ``cls`` - confirm with callers).
        log (bool, optional): When True the image shows
            ``10 * log10(ffts)`` and the colour bar is labelled in dB.
            Defaults to False.
        name (str, optional): Figure super-title; also used (sanitised) as
            the file name when ``save`` is True. Defaults to None.
        save (bool, optional): When True and ``name`` is a string, write the
            figure to ``figures/FFT_by/<name>.png``. Defaults to False.

    Returns:
        The newly created matplotlib figure.
    """
    f = plt.figure()
    ax = plt.subplot(1,1,1)

    image = np.float64(ffts)
    if log:
        image = 10 * np.log10(image)

    extent = (freqs[0], freqs[-1], 0, len(cls))
    im1 = plt.imshow(image, "inferno", origin="lower", aspect="auto", extent=extent)

    plt.colorbar(im1, label="Magnitude" + (" (dB)" if log else ""))

    # Frequency resolution = spacing between neighbouring bins. Dividing the
    # last frequency by the number of bins (as done before) is off by one
    # interval and wrong whenever the axis does not start at 0.
    if len(freqs) > 1:
        df = (freqs[-1] - freqs[0]) / (len(freqs) - 1)
    else:
        df = float("nan")  # a single bin has no spacing
    # "\\ " keeps the mathtext spacing command without an invalid escape.
    ax.set_xlabel(f"Frequency (Hz) [$\\Delta f = {df:.3e}\\ Hz$]")
    ax.set_yticks(range(len(cls)), labels=cls)
    ax.grid(axis="y")

    if name is not None:
        plt.suptitle(name)

    plt.tight_layout()

    # Saving needs a string name to derive the file name from.
    if isinstance(name, str) and save:
        # Anything that is not a word character, parenthesis or space
        # becomes "_" so the title is usable as a file name.
        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
        # Build the folder from components: a literal "figures\\FFT_by\\"
        # is only split into directories on Windows.
        folder = Path("figures") / "FFT_by"
        folder.mkdir(parents=True, exist_ok=True)
        plt.savefig(folder / (escaped + ".png"))

    return f

# %% -
def savefig(name: str):
    """Save the current matplotlib figure as ``figures/<name>.png``.

    ``name`` may contain backslashes to designate sub-folders (for example
    ``"Spectrogram\\\\run 1"``); missing folders are created. Characters
    other than word characters, parentheses, spaces, backslashes, square
    brackets, ``=``, ``.`` and ``,`` are replaced by ``_``.

    Args:
        name (str): Target name relative to the ``figures`` folder, without
            the ``.png`` extension.
    """
    escaped = re.sub(r"[^\w\d\(\)\ \\\[\]=\.,]", "_", name)
    # Split on the backslash ourselves: pathlib only treats "\\" as a
    # separator on Windows, so a single "figures\\..." string would become
    # one oddly named file on other systems.
    *subfolders, stem = escaped.split("\\")
    path = Path("figures").joinpath(*subfolders, stem + ".png")
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(path)

# %%
def Spectrogram_plot(
        time: np.ndarray,
        Sx: np.ndarray,
        STFT: scs.ShortTimeFFT,
        log: bool = False,
        name: str = "Spectrogram",
        save: bool = False
    ):
    """Draw the magnitude of a short-time Fourier transform as an image.

    Args:
        time (np.ndarray): Time axis of the analysed signal; only its length
            is used, to query ``STFT.extent``.
        Sx (np.ndarray): STFT values; their absolute value is displayed.
        STFT (scs.ShortTimeFFT): Transform object providing ``extent``,
            ``delta_t`` and ``delta_f`` for the axes.
        log (bool, optional): When True the image shows
            ``20 * log10(|Sx|)`` and the colour bar is labelled in dB.
            Defaults to False.
        name (str, optional): Figure super-title and file name when saving.
            Defaults to "Spectrogram".
        save (bool, optional): When True, save through ``savefig`` under the
            ``Spectrogram`` sub-folder. Defaults to False.

    Returns:
        tuple: ``(figure, extent)`` where ``extent`` is the value returned
        by ``STFT.extent(len(time))``.
    """
    f = plt.figure(figsize=(12.8,7.2))

    extent = STFT.extent(len(time))
    # The first two entries of the extent bound the time axis.
    to_lo, to_hi = extent[:2]

    ax1 = f.subplots(1, 1)
    ax1.set(
        xlabel=rf"Time (s) [$\Delta t = {STFT.delta_t:g}\ s$]",
        ylabel=rf"Frequency (Hz) [$\Delta f = {STFT.delta_f:g}\ Hz$]",
        xlim=(to_lo, to_hi)
    )

    image = np.abs(Sx)
    if log:
        # Reuse the magnitude computed above instead of taking abs() twice.
        image = 20 * np.log10(image)

    im1 = ax1.imshow(
        image, extent=extent,
        cmap="inferno", origin="lower", aspect="auto"
    )
    # "dB" is the correct unit symbol (was "db").
    f.colorbar(im1, label=r"Magnitude $S_x(t, f)$" + (" (dB)" if log else ""))

    f.suptitle(name)
    f.tight_layout()

    if save:
        # savefig uses the backslash as its sub-folder separator.
        savefig("Spectrogram\\" + name)

    return f, extent

def plot_unique_position(data:pd.DataFrame, axis:plt.Axes, byvariable:str, Motor:int) -> plt.Axes:
    """Overlay one run's motor position on a twin y axis.

    When ``data`` holds several runs (told apart by the ``byvariable``
    column), only the rows sharing the ``byvariable`` value of the first
    row are drawn, so the position trace appears once.

    Args:
        data (pd.DataFrame): Data with ``Sample_time``, ``byvariable`` and
            ``Position_A{Motor}`` columns.
        axis (plt.Axes): Axes whose x axis is shared by the new twin axes.
        byvariable (str): Column distinguishing the runs.
        Motor (int): Number of the motor whose position is plotted.

    Returns:
        plt.Axes: The value returned by ``sns.lineplot`` for the twin axes.
    """
    twin_axis = axis.twinx()

    # Keep only the run whose 'byvariable' equals that of the very first row.
    run_labels = data[byvariable]
    first_label = data.iloc[0][byvariable]
    single_run = data.loc[run_labels == first_label]

    position_axis = sns.lineplot(
        data=single_run,
        x='Sample_time',
        y=f'Position_A{Motor}',
        color='grey',
        alpha=0.1,
        ax=twin_axis,
        label=f'Axis {Motor} position',
        legend=False,
    )
    twin_axis.set_ylabel(f'Motor {Motor} angle (°)')
    return position_axis

def check_arguments(plot_type:str, variable:str, byvariable:str):
    """Validate the plot selection arguments.

    The checks run in order (plot type, variable, by-variable); the first
    failing one prints its message and ends the validation.

    Args:
        plot_type (str): Must be 'Lineplot' or 'Boxplot'.
        variable (str): Must be 'Current', 'Temperature' or 'Position_Error'.
        byvariable (str): Must be 'Class' or 'Speed'.

    Returns:
        bool: True when all three arguments are accepted, False otherwise.
    """
    # (received value, accepted values, message printed on rejection)
    rules = (
        (plot_type, ('Lineplot', 'Boxplot'), 'Bad plot type argument'),
        (variable, ('Current', 'Temperature', 'Position_Error'), 'Bad variable argument'),
        (byvariable, ('Class', 'Speed'), 'Bad by variable argument'),
    )
    for received, accepted, complaint in rules:
        if received not in accepted:
            print(complaint)
            return False
    return True

def plot_all_axis(DB, variable:str, byvariable:str, byvalues:List, MovingMotor:int, Speed:int = 50, Class:int = 0, plot_type:str = 'Lineplot', doposition:bool = False, saving:bool = False):
    """Plot ``variable`` on each of the six robot axes while one motor moves.

    One run is fetched from the database for every value in ``byvalues``;
    the runs are told apart by the ``byvariable`` column. The figure holds
    six subplots (one per motor), drawn as line plots or box plots. The RMS
    of ``{variable}_A{MovingMotor}`` over each complete run is printed.

    Args:
        DB: Database object exposing the
            ``robot(...).by_class(...).by_speed(...).by_moving_motor(...)
            .select_column(...).run()`` query chain. Robot 2 is always queried.
        variable (str): Variable to plot ('Current', 'Temperature' or
            'Position_Error').
        byvariable (str): Column discriminating the runs ('Class' or 'Speed').
        byvalues (List): Values of ``byvariable`` to fetch, one run each.
        MovingMotor (int): Number of the moving motor.
        Speed (int, optional): Speed used when ``byvariable`` is 'Class'.
            Defaults to 50.
        Class (int, optional): Load class used when ``byvariable`` is
            'Speed'. Defaults to 0.
        plot_type (str, optional): 'Lineplot' or 'Boxplot'. Defaults to
            'Lineplot'.
        doposition (bool, optional): For line plots, also draw the moving
            motor's position on a secondary axis. Defaults to False.
        saving (bool, optional): Save the figure under
            ``figures/{variable} by motor by {byvariable}``. Defaults to False.

    Returns:
        The created figure, or None when the arguments are rejected by
        ``check_arguments``.
    """
    if not check_arguments(plot_type, variable, byvariable):
        return

    columns = [byvariable, 'Sample_time', f'Position_A{MovingMotor}', *[f'{variable}_A{j}' for j in range(1, 7)]]

    # Data gathering: frames are collected and concatenated once afterwards,
    # which avoids concatenating onto an empty DataFrame and re-copying the
    # accumulated data on every iteration.
    frames = []
    rms = []
    for v in byvalues:
        if byvariable == 'Speed':
            query = DB.robot(2).by_class(Class).by_speed(v)
            window = slice(0, 500)      # limits data to 2-3 iterations
        else:   # 'Class' (guaranteed by check_arguments)
            query = DB.robot(2).by_class(v).by_speed(Speed)
            window = slice(300, 900)    # limits data to 2-3 iterations
        # Both branches now fetch only the columns that are plotted.
        df = query.by_moving_motor(MovingMotor).select_column(*columns).run()
        df[byvariable] = df[byvariable].astype('category')
        df['Sample_time'] -= df['Sample_time'].min()    # start every run at t = 0
        frames.append(df[window])
        # The RMS is taken over the whole run, not only the plotted window.
        rms.append(np.sqrt(np.mean(df[f'{variable}_A{MovingMotor}']**2)))
    dataframe = pd.concat(frames) if frames else pd.DataFrame()
    print(f'Motor {MovingMotor} runs rms :',rms)

    fig = plt.figure(figsize=(3*6,2*4)) # Data plot
    for i in range(1, 7):
        axis = plt.subplot(2,3,i)
        if plot_type == 'Boxplot':
            sns.boxplot(data=dataframe, y=f'{variable}_A{i}', hue=byvariable, ax=axis)
        elif plot_type == 'Lineplot':
            lin1 = sns.lineplot(data=dataframe, x='Sample_time', y=f'{variable}_A{i}', hue=byvariable, ax=axis, alpha=0.6)
            if doposition:
                lin2 = plot_unique_position(dataframe, axis, byvariable, MovingMotor)
                # Merge the handles of both y axes into a single legend.
                axis.legend(fontsize=8, loc='upper right', handles=lin1.get_lines()+lin2.get_lines())
            else:
                axis.legend(fontsize=8, loc='upper right')
        axis.set_title(f'Motor {i}')

    if byvariable == 'Speed':
        fig.suptitle(f"{plot_type} of {variable} by {byvariable} for Axis {MovingMotor} moving loaded with {from_class_to_text(Class)}")
    elif byvariable == 'Class':
        fig.suptitle(f"{plot_type} of {variable} by {byvariable} for Axis {MovingMotor} moving at {Speed}% Speed")

    fig.tight_layout()

    if saving:
        figures = Path(f"figures/{variable} by motor by {byvariable}")
        figures.mkdir(parents=True, exist_ok=True)
        fig.savefig(figures / f"{plot_type}_of_{variable}_by{byvariable}_Axis{MovingMotor}.png")

    return fig


def plot_moving_axes(DB, variable:str, byvariable:str, byvalues:List, Speed:int = 50, Class:int = 0, plot_type:str = 'Lineplot', doposition:bool = False, saving:bool = False):
    """Plot, for each axis, ``variable`` of the motor that moves that axis.

    Subplot *k* shows ``{variable}_A{k}`` recorded while axis *k* was the
    moving one (A1 moving - motor 1, A2 moving - motor 2, ...). One run is
    fetched for every value in ``byvalues``; the runs are told apart by the
    ``byvariable`` column. The RMS of each complete run is printed per motor.

    Args:
        DB: Database object exposing the
            ``robot(...).by_class(...).by_speed(...).by_moving_motor(...)
            .select_column(...).run()`` query chain. Robot 2 is always queried.
        variable (str): Variable to plot ('Current', 'Temperature' or
            'Position_Error').
        byvariable (str): Column discriminating the runs ('Class' or 'Speed').
        byvalues (List): Values of ``byvariable`` to fetch, one run each.
        Speed (int, optional): Speed used when ``byvariable`` is 'Class'.
            Defaults to 50.
        Class (int, optional): Load class used when ``byvariable`` is
            'Speed'. Defaults to 0.
        plot_type (str, optional): 'Lineplot' or 'Boxplot'. Defaults to
            'Lineplot'.
        doposition (bool, optional): For line plots, also draw the motor's
            position on a secondary axis. Defaults to False.
        saving (bool, optional): Save the figure under
            ``figures/{variable} by moving motors by {byvariable}``.
            Defaults to False.

    Returns:
        The created figure, or None when the arguments are rejected by
        ``check_arguments``.
    """
    if not check_arguments(plot_type, variable, byvariable):
        return

    fig = plt.figure(figsize=(3*6,2*4))
    for Motor in range(1, 7):
        columns = [byvariable, 'Sample_time', f'{variable}_A{Motor}', f'Position_A{Motor}']

        # Data gathering: frames are collected and concatenated once, which
        # avoids concatenating onto an empty DataFrame inside the loop.
        frames = []
        rms = []
        for v in byvalues:
            if byvariable == 'Speed':
                query = DB.robot(2).by_class(Class).by_speed(v)
                window = slice(0, 500)      # limits data to 2-3 iterations
            else:   # 'Class' (guaranteed by check_arguments)
                query = DB.robot(2).by_class(v).by_speed(Speed)
                window = slice(300, 900)    # limits data to 2-3 iterations
            df = query.by_moving_motor(Motor).select_column(*columns).run()
            df[byvariable] = df[byvariable].astype('category')
            df['Sample_time'] -= df['Sample_time'].min()    # start every run at t = 0
            frames.append(df[window])
            # The RMS is taken over the whole run, not only the plotted window.
            rms.append(np.sqrt(np.mean(df[f'{variable}_A{Motor}']**2)))
        dataframe = pd.concat(frames) if frames else pd.DataFrame()
        print(f'Motor {Motor} runs rms :',rms)

        axis = plt.subplot(2,3,Motor)
        if plot_type == 'Boxplot':
            sns.boxplot(data=dataframe, y=f'{variable}_A{Motor}', hue=byvariable, ax=axis)
        elif plot_type == 'Lineplot':
            lin1 = sns.lineplot(data=dataframe, x='Sample_time', y=f'{variable}_A{Motor}', hue=byvariable, ax=axis, alpha=0.6)
            lin1.set_ylabel(f'{variable}_A{Motor}, Axis {Motor} moving')
            if doposition:
                lin2 = plot_unique_position(dataframe, axis, byvariable, Motor)
                # Merge the handles of both y axes into a single legend.
                axis.legend(fontsize=8, loc='upper right', handles=lin1.get_lines()+lin2.get_lines())
            else:
                axis.legend(fontsize=8, loc='upper right')
        axis.set_title(f'Motor {Motor}')

    if byvariable == 'Speed':
        fig.suptitle(f"{plot_type} of {variable} by {byvariable} loaded with {from_class_to_text(Class)}")
    elif byvariable == 'Class':
        fig.suptitle(f"{plot_type} of {variable} by {byvariable} moving at {Speed}% Speed")

    fig.tight_layout()

    if saving:
        figures = Path(f"figures/{variable} by moving motors by {byvariable}")
        figures.mkdir(parents=True, exist_ok=True)
        fig.savefig(figures / f"{plot_type}_of_{variable}_by{byvariable}{str(byvalues)}_allAxes.png")

    return fig

def plot_grouped_load(DB, variable:str, Classes:List[List], Motor:int, Speed:int, doposition:bool = False, saving:bool = False) -> plt.Figure:
    """Plot ``variable`` of one motor, grouping the curves by load class.

    Each inner list of ``Classes`` produces one subplot holding one curve
    per class of that list. The RMS of ``{variable}_A{Motor}`` over each
    complete run is printed per group.

    Args:
        DB: Database object exposing the
            ``robot(...).by_class(...).by_speed(...).by_moving_motor(...)
            .select_column(...).run()`` query chain. Robot 2 is always queried.
        variable (str): Variable prefix of the ``{variable}_A{n}`` columns.
        Classes (List[List]): Groups of load classes, one subplot per group.
        Motor (int): Motor number (moving motor and plotted motor).
        Speed (int): Motor speed.
        doposition (bool, optional): Also draw the motor's position on a
            secondary axis. Defaults to False.
        saving (bool, optional): Save the figure under
            ``figures/{variable} by motor by grouped Load``. Defaults to False.

    Returns:
        plt.Figure: The created figure.
    """
    columns = ['Class', 'Sample_time', f'Position_A{Motor}', *[f'{variable}_A{j}' for j in range(1, 7)]]
    fig = plt.figure(figsize=(len(Classes)*5,4))

    for i, cla in enumerate(Classes):
        axis = plt.subplot(1,len(Classes),i+1)

        # Data gathering from the database: frames are collected and
        # concatenated once instead of growing a DataFrame inside the loop.
        frames = []
        rms = []
        for c in cla:
            df = DB.robot(2).by_class(c).by_speed(Speed).by_moving_motor(Motor).select_column(*columns).run()
            df['Sample_time'] -= df['Sample_time'].min()    # start every run at t = 0
            df['Class'] = df.Class.astype('category')
            frames.append(df[300:900])  # limits data to 2-3 iterations
            # The RMS is taken over the whole run, not only the plotted window.
            rms.append(np.sqrt(np.mean(df[f'{variable}_A{Motor}']**2)))
        dataframe = pd.concat(frames) if frames else pd.DataFrame()
        print(f'Motor {Motor} runs rms :',rms)

        lin1 = sns.lineplot(data=dataframe, x='Sample_time', y=f'{variable}_A{Motor}', hue='Class', ax=axis, alpha=0.6) # data plot

        if doposition:
            lin2 = plot_unique_position(dataframe, axis, 'Class', Motor)
            # Merge the handles of both y axes into a single legend.
            axis.legend(fontsize=8, loc='upper right', handles=lin1.get_lines()+lin2.get_lines())
        else:
            axis.legend(fontsize=8, loc='upper right')
        # A further unconditional legend call referencing lin2 used to follow
        # here; it raised UnboundLocalError whenever doposition was False.

    fig.suptitle(f'{variable} of Motor {Motor} running at {Speed}% speed grouped by load')
    fig.tight_layout()

    if saving:
        figures = Path(f"figures/{variable} by motor by grouped Load")
        figures.mkdir(parents=True, exist_ok=True)
        fig.savefig(figures / f"{variable}_A{Motor}_by_grouped_load.png")

    return fig