From b077f442f4b30c14c66ba3e99e23632f4e73875a Mon Sep 17 00:00:00 2001
From: VulcanixFR <vulcanix.gamingfr@gmail.com>
Date: Thu, 18 Jul 2024 15:44:32 +0200
Subject: [PATCH] Added spectrograms and RMS to processing tools

---
 .vscode/settings.json |   3 +
 tools/plots.py        |  68 ++++++++++++++--
 tools/processing.py   | 178 +++++++++++++++++++++++++++---------------
 3 files changed, 176 insertions(+), 73 deletions(-)
 create mode 100644 .vscode/settings.json

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..b881eff
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.analysis.autoImportCompletions": true
+}
\ No newline at end of file
diff --git a/tools/plots.py b/tools/plots.py
index 1d95703..5d74701 100644
--- a/tools/plots.py
+++ b/tools/plots.py
@@ -46,10 +46,11 @@ def correlation_plot (lag: np.ndarray, correlation: np.ndarray, name: str = None
     plt.tight_layout()
 
     if type(name) == str and save:
-        folder = Path("figures\\autocorrelation\\")
+        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
+        folder = Path("figures\\correlation\\")
         if not folder.exists():
             folder.mkdir(parents=True)
-        plt.savefig(folder.joinpath(name + ".png"))
+        plt.savefig(folder.joinpath(escaped + ".png"))
 
     return f
 
@@ -72,10 +73,11 @@ def FFT_plot (freqs: np.ndarray, fft: np.ndarray, log: bool = False, name: str =
     plt.tight_layout()
 
     if type(name) == str and save:
+        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
         folder = Path("figures\\FFT\\")
         if not folder.exists():
             folder.mkdir(parents=True)
-        plt.savefig(folder.joinpath(name + ".png"))
+        plt.savefig(folder.joinpath(escaped + ".png"))
 
     return f
 
@@ -87,7 +89,7 @@ def Time_plot (time: np.ndarray, signal: np.ndarray, name: str = None, save: boo
     plt.plot(time, signal)
     
     plt.grid()
-    plt.xlabel("Frequency (Hz)")
+    plt.xlabel("Time (s)")
     plt.ylabel("Amplitude" if name is None else name)
 
     if type(name) == str:
@@ -96,10 +98,11 @@ def Time_plot (time: np.ndarray, signal: np.ndarray, name: str = None, save: boo
     plt.tight_layout()
 
     if type(name) == str and save:
-        folder = Path("figures\\FFT\\")
+        escaped = re.sub(r"[^\w\d\(\)\ ]", "_", name)
+        folder = Path("figures\\Time\\")
         if not folder.exists():
             folder.mkdir(parents=True)
-        plt.savefig(folder.joinpath(name + ".png"))
+        plt.savefig(folder.joinpath(escaped + ".png"))
 
     return f
 
@@ -117,7 +120,8 @@ def FFT_by_plot (cls: List[Any], freqs: np.ndarray, ffts: List[np.ndarray], log:
     im1 = plt.imshow(image, "inferno", origin="lower", aspect="auto", extent=extent)
     
     plt.colorbar(im1, label="Magnitude" + (" (dB)" if log else ""))
-    ax.set_xlabel("Fequency (Hz)")
+    df = freqs[-1] / len(freqs)
+    ax.set_xlabel(f"Fequency (Hz) [$\\Delta f = {df:.3e}\ Hz$]")
     ax.set_yticks(range(len(cls)), labels=cls)
     ax.grid(axis="y")
 
@@ -132,4 +136,52 @@ def FFT_by_plot (cls: List[Any], freqs: np.ndarray, ffts: List[np.ndarray], log:
             folder.mkdir(parents=True)
         plt.savefig(folder.joinpath(escaped + ".png"))
 
-    return f
\ No newline at end of file
+    return f
+
+# %% -
+def savefig (name: str):
+    escaped = re.sub(r"[^\w\d\(\)\ \\\[\]=\.,]", "_", name)
+    path = Path(f"figures\\{escaped}.png")
+    if not path.parent.exists():
+        path.parent.mkdir(parents=True)
+    plt.savefig(path)
+
+# %%
+def Spectrogram_plot (
+        time: np.ndarray, 
+        Sx: np.ndarray, 
+        STFT: scs.ShortTimeFFT, 
+        log: bool = False, 
+        name: str = "Spectrogram", 
+        save: bool = False
+    ):
+
+    f = plt.figure(figsize=(12.8,7.2))
+
+    extent = STFT.extent(len(time))
+    to_lo, to_hi = extent[:2]
+    
+    ax1 = f.subplots(1, 1)
+    ax1.set(
+        xlabel=rf"Time (s) [$\Delta t = {STFT.delta_t:g}\ s$]",
+        ylabel=rf"Frequency (Hz) [$\Delta f = {STFT.delta_f:g}\ Hz$]",
+        xlim=(to_lo, to_hi)
+    )
+
+    image = np.abs(Sx)
+    if log:
+        image = 20 * np.log10(np.abs(Sx))
+
+    im1 = ax1.imshow(
+        image, extent=extent, 
+        cmap="inferno", origin="lower", aspect="auto"
+    )
+    f.colorbar(im1, label=fr"Magnitude $S_x(t, f)$" + (" (db)" if log else ""))
+
+    f.suptitle(name)
+    f.tight_layout()
+
+    if save:
+        savefig("Spectrogram\\" + name)
+
+    return f, extent
\ No newline at end of file
diff --git a/tools/processing.py b/tools/processing.py
index 41ce789..346ce9b 100644
--- a/tools/processing.py
+++ b/tools/processing.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Any
 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
@@ -182,75 +182,33 @@ def find_plateau (mask: np.ndarray, start: int, backwards: bool = False, W: int
 
     return a, b, w
 
-#%%
-def make_periodic_2 (X: np.ndarray, Y_Current: np.ndarray, Y_Position: np.ndarray, W = 5, plot = False):
-
-    Y2 = np.diff(Y_Position - Y_Position.mean())
-    Y2 /= np.max(Y2)
-    # Findig all places where the signal is not changing
-    mask = (np.abs(Y2) > 0.1) * 1
-
-    if plot:
-        plt.figure()
-        plt.plot(X[:-1],np.abs(Y2), label="$x'(t)$")
-        plt.plot(X[:-1],mask * 0.1, label="mask")
-
-    start = 0
-    plateaus = []
-    a,b,w = find_plateau(mask, start, False, W)
-    while w > 0:
-        plateaus.append([a,b,w])
-        start = b + 1
-        a,b,w = find_plateau(mask, start, False, W)
-
-    plateaus = np.int16(plateaus)
-
-    if plot:
-        _a = []
-        _b = []
-        for a,b,w in plateaus:
-            _a.append(X[a])
-            _b.append(X[b])
-        plt.scatter(_a, np.ones(len(_a)) * 0.5, marker="*", c="g", label="Plateau start")
-        plt.scatter(_b, np.ones(len(_a)) * 0.5, marker="*", c="m", label="Plateau end")
-
-    max_w = plateaus[:,2].max()
-    
-    stable = plateaus[np.where(np.abs((plateaus[:,2] - max_w) / max_w) < 0.2)]
-
-    zeros = stable[::2]
-
-    if plot:
-        _z = [ X[a] for a,b,w in zeros ]
-        plt.scatter(_z, np.zeros(len(_z)), c='r', marker="x", label="Zero")
-
-        plt.xlabel("Time (s)")
-        plt.ylabel("Amplitude (normalized)")
-        plt.tight_layout()
-        plt.legend()
-
-    # start = int(0.5 * (zeros[0][0]  + zeros[0][1] ))
-    # end   = int(0.5 * (zeros[-1][0] + zeros[-1][1]))
-    start = zeros[0][0]
-    end = zeros[-1][0]
-
-    periods = len(zeros) - 1
-    fmov = periods / (X[end] - X[start])
-
-    XP = X[start:end].copy()
-    CP = Y_Current[start:end].copy()
-    PP = Y_Position[start:end].copy()
-
-    return XP, CP, PP, fmov, periods, plateaus
-
 # %% -
 def FFT_by (
         data: pd.DataFrame, by: str, 
-        time: str, position: str, current: str
+        time: str, position: str, current: str, Ts=4
     ):
+    """Makes an FFT for each of the values taken by the "by" variable
+
+    Args:
+        data (pd.DataFrame): Data to analyse
+        by (str): Name of the attribute to compare
+        time (str): Name of the time column
+        position (str): Name of the position column
+        current (str): Name of the current column
+        Ts (int): Sampling period in ms 
+
+    Returns:
+        Tuple[cls, freqs, ffts, Tuple[times, positions, currents]]
+        cls (List[Any]): Values taken by "by"
+        freqs (ndarray): Frequency axis
+        ffts (List[ndarray]): FFT for each cls
+        times (List[ndarray]): Time for each cls
+        positions (List[ndarray]): Position for each cls
+        currents (List[ndarray]): Current for each cls
+    """
 
     # Define variables
-    cls: List[Any] = data[by].unique()
+    cls: List[Any] = np.sort(data[by].unique())
     freqs: np.ndarray = None
     ffts: List[np.ndarray] = []
 
@@ -267,7 +225,7 @@ def FFT_by (
         P = subdata[position].to_numpy()
         C = subdata[current].to_numpy()
 
-        TP, CP, PP, fmov, periods, plateaus = make_periodic_2(T, C, P)
+        TP, [CP], PP, fmov, periods, plateaus = make_periodic(T, P, [C])
 
         times.append(TP)
         positions.append(PP)
@@ -333,3 +291,93 @@ def make_periodic (time: np.ndarray, position: np.ndarray, currents: List[np.nda
     PP = position[start:end].copy()
 
     return XP, CP, PP, fmov, periods, plateaus
+
+
+# %%
+def RMS (signal: np.ndarray) -> float:
+    """Computes the root mean square value of the given signal
+
+    Args:
+        signal (np.ndarray): Input signal
+
+    Returns:
+        float: The RMS of the signal
+    """
+    return np.sqrt(np.mean(signal ** 2))
+
+# %%
+def moving_rms (
+        time: np.ndarray, 
+        signal: np.ndarray, 
+        window: float = 2,
+        overlap: float = 0.5
+    ) -> Tuple[np.ndarray, np.ndarray]:
+    """Computes the RMS value of the signal using a moving window, 
+    with 50% overlap. The RMS sampling rate is window / 2.
+
+    Args:
+        time (np.ndarray): Time axis
+        signal (np.ndarray): Signal to compute the RMS from
+        window (float, optional): Width of the window (in s). Defaults to 2s.
+        overlap (float): percentage of the window overlapping with the previous one
+        
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: Time and RMS axis
+    """
+
+    Ts = time[1] - time[0] # Sampling rate
+    rms_width = int(window / Ts) # samples
+    rms_hop = max(1, min(int((1 - overlap) * rms_width), rms_width))
+    hops = np.arange(0, len(time), rms_hop)
+    rms_segments = len(hops) # With overlap
+
+    rms_time = hops * Ts + time[0] # Time axis
+    rms_over_time = np.zeros(rms_segments) # RMS axis
+
+    for k in range(rms_segments):
+        a = hops[k]
+        b = a + int(rms_width)
+        rms_over_time[k] = RMS(signal[a:b])
+
+    return rms_time, rms_over_time
+
+# %% - 
+def Spectrogram (
+        signal: np.ndarray, 
+        widow_width: float = 1, Ts: float = 0.004, no_overlap=False
+    ) -> Tuple[np.ndarray, scs.ShortTimeFFT]:
+    
+    """Returns a 2D Spectrogram of the signal
+
+    Args:
+        time (np.ndarray): The time axis
+        signal (np.ndarray): The signal to analyse
+        widow_width (float): The length of the window (in s, 1 by defaut)
+        Ts (float): The sampling rate (in s, 0.004 by default)
+        no_overlap (bool): Disables window overlapping
+
+    Returns:
+        Tuple[np.ndarray, scs.ShortTimeFFT: The 2D spectrogram and the STFT object
+    """
+
+    # Sampling frequency
+    Fs = 1 / Ts
+
+    # Width in n° of samples
+    width = int(widow_width / Ts)
+
+    # Overlap
+    hop = width if no_overlap else width // 4
+
+    # Window
+    window = np.hanning(width)
+
+    # Zero-padding
+    pad = len(signal) * 5
+
+    # Get the spectrogram
+    SFT = scs.ShortTimeFFT(window, hop, Fs, mfft=pad, scale_to="magnitude")
+    Sx = SFT.stft(signal)    
+
+    return Sx, SFT
+
-- 
GitLab