2025-03-11 21:49:50 +08:00
|
|
|
|
import time
|
2025-03-08 20:25:01 +08:00
|
|
|
|
import numpy as np
|
2025-03-11 21:49:50 +08:00
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
from reservoirpy import datasets
|
|
|
|
|
|
|
|
|
|
|
|
def print_exec_time(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The wrapped function's return value is passed through unchanged. Its
    metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...) is preserved via
    ``functools.wraps`` so introspection and stacked decorators keep working.
    Nothing is printed if ``func`` raises.
    """
    import functools  # local import: the file header only imports ``time``

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time can jump
        # when the system clock is adjusted.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"Function {func.__name__} executed in {elapsed:.4f} seconds.")
        return result

    return wrapper
|
2025-03-08 20:25:01 +08:00
|
|
|
|
|
|
|
|
|
|
def add_noise(data, noise_type='gaussian', intensity=0.1, **kwargs):
    """Add synthetic noise to ``data``.

    Parameters
    ----------
    data : array_like
        Original data. Axis 0 is the sample (time) axis; any trailing axes are
        treated as independent channels.
    noise_type : str
        Case-insensitive noise type:
        'gaussian' / 'gaussian white' (Gaussian white noise),
        'colored' / '1/f' (1/f colored noise),
        'impulse' (sparse impulse noise) or 'sine' (sinusoidal noise).
    intensity : float
        Noise intensity. For 'gaussian' it is the noise standard deviation;
        for 'colored' each channel is rescaled to this standard deviation;
        for 'impulse' impulses are uniform in [-intensity, intensity);
        for 'sine' it is the sine amplitude.
    **kwargs
        Per-type options (unknown keys are ignored):
        - 'impulse': ``prob`` - probability of an impulse per entry (default 0.01)
        - 'sine':    ``freq`` - number of periods over the sample axis (default 5)

    Returns
    -------
    numpy.ndarray
        ``data`` plus noise, with the same shape as ``data``.

    Raises
    ------
    ValueError
        If ``noise_type`` is not one of the supported types.
    """
    data = np.asarray(data)
    noise_type = noise_type.lower()
    n_samples = data.shape[0]

    if noise_type in ['gaussian', 'gaussian white']:
        noise = intensity * np.random.randn(*data.shape)

    elif noise_type in ['colored', '1/f']:
        noise = _colored_noise(data.shape, intensity)

    elif noise_type == 'impulse':
        prob = kwargs.get('prob', 0.01)
        mask = np.random.rand(*data.shape) < prob
        impulse = intensity * (2 * np.random.rand(*data.shape) - 1)
        # Build a float noise array: np.zeros_like(data) would silently
        # truncate the impulses to 0 when data has an integer dtype.
        noise = np.where(mask, impulse, 0.0)

    elif noise_type == 'sine':
        t = np.arange(n_samples)
        freq = kwargs.get('freq', 5)
        wave = intensity * np.sin(2 * np.pi * freq * t / n_samples)
        # Same waveform on every channel: shape (n_samples, 1, ..., 1)
        # broadcasts over any number of trailing axes.
        noise = wave.reshape((n_samples,) + (1,) * (data.ndim - 1))

    else:
        raise ValueError("未知的噪声类型。可选项:'gaussian', 'colored' (1/f), 'impulse', 'sine'")

    return data + noise


def _colored_noise(shape, intensity):
    """Return 1/f noise of ``shape``; every channel has zero mean and std ``intensity``."""
    n_samples = shape[0]
    n_channels = int(np.prod(shape[1:]))  # 1 for 1-D data
    noise = np.zeros((n_samples, n_channels))
    if n_samples == 0:
        return noise.reshape(shape)

    freqs = np.fft.fftfreq(n_samples)
    # Spectral amplitude ~ 1/sqrt(|f|) gives a 1/f power spectrum. The DC bin
    # gets zero weight: a tiny placeholder frequency there (e.g. 1e-10) would
    # inflate it ~1e5x and turn the noise into a large constant offset.
    scale = np.zeros(n_samples)
    scale[1:] = 1.0 / np.sqrt(np.abs(freqs[1:]))

    for ch in range(n_channels):
        spectrum = np.random.randn(n_samples) + 1j * np.random.randn(n_samples)
        series = np.fft.ifft(spectrum * scale).real
        std = series.std()
        if std > 0:
            # Normalize so `intensity` is the noise std, as for Gaussian noise.
            series = series * (intensity / std)
        noise[:, ch] = series
    return noise.reshape(shape)
|
|
|
|
|
|
|
2025-03-27 21:53:48 +08:00
|
|
|
|
def load_data(system='lorenz',
              init='random',
              noise=None,
              intensity=0.1,
              h=0.01,
              n_timesteps=10000,
              transient=1000,
              normlization=True, **kwargs):
    """Load a chaotic-system trajectory and a noise-corrupted copy of it.

    Parameters
    ----------
    system : str
        Case-insensitive system name: 'lorenz' or 'kuramoto_sivashinsky'.
    init : 'random' or array_like
        Initial state, used when ``x0`` is not supplied in ``kwargs``.
        For 'lorenz', 'random' draws a uniform state in [0, 1)^3. For
        'kuramoto_sivashinsky', 'random' leaves reservoirpy's own default
        initial condition in place.
    noise : None, str, or list/tuple of str
        Noise type(s) to add: 'gaussian', 'colored' / '1/f', 'impulse', 'sine'.
        With a sequence, each type is generated at ``intensity`` and the
        results are averaged into one mixed noise. ``None`` or an empty
        sequence adds no noise.
    intensity : float
        Noise intensity forwarded to ``add_noise``.
    h : float
        Integration time step.
    n_timesteps : int
        Number of simulated time steps, including the transient.
    transient : int
        Number of leading time steps to discard.
    normlization : bool
        If True, min-max scale each column of the kept data to [0, 1].
    **kwargs
        System parameters and noise options:
        - lorenz: ``sigma`` (default 10), ``rho`` (28), ``beta`` (8/3), ``x0``
        - kuramoto_sivashinsky: ``N``, ``M``, ``warmup``, ``x0``
          (reservoirpy defaults N=128, M=16)
        - noise options for ``add_noise``, e.g. ``prob``, ``freq``.
        A non-None ``x0`` takes precedence over ``init``.

    Returns
    -------
    (clean_data, noisy_data) : tuple of numpy.ndarray
        Clean trajectory with the first ``transient`` rows removed, and a
        noisy copy of the same shape.

    Raises
    ------
    ValueError
        For an unknown ``system`` or an invalid ``noise`` argument.
    """
    system = system.lower()
    if system not in ('lorenz', 'kuramoto_sivashinsky'):
        raise ValueError("未知的混沌系统类型。可选项: 'lorenz', 'kuramoto_sivashinsky'")

    clean_data = _simulate_chaotic_system(system, init, h, n_timesteps, kwargs)

    # Drop the transient BEFORE normalizing so the min/max statistics describe
    # only the samples that are actually returned.
    clean_data = clean_data[transient:, :]

    if normlization:
        if clean_data.shape[0] > 0:  # min/max of an empty array would raise
            eps = 1e-8  # keeps constant columns from dividing by zero
            col_min = clean_data.min(axis=0)
            col_range = clean_data.max(axis=0) - col_min
            clean_data = (clean_data - col_min) / (col_range + eps)
        print('use new 0-1 norm')

    noisy_data = _make_noisy_copy(clean_data, noise, intensity, kwargs)
    return clean_data, noisy_data


def _is_random_init(init):
    """True when ``init`` is the string 'random' (safe for ndarray inputs)."""
    return isinstance(init, str) and init == 'random'


def _simulate_chaotic_system(system, init, h, n_timesteps, kwargs):
    """Integrate ``system`` with reservoirpy and return the raw trajectory."""
    x0 = kwargs.get('x0')

    if system == 'lorenz':
        if x0 is not None:
            state = np.asarray(x0, dtype=float)
        elif _is_random_init(init):
            state = np.random.rand(3)
        else:
            state = np.asarray(init, dtype=float)
        return datasets.lorenz(n_timesteps=n_timesteps, h=h,
                               sigma=kwargs.get('sigma', 10.0),
                               rho=kwargs.get('rho', 28.0),
                               beta=kwargs.get('beta', 8 / 3),
                               x0=state)

    # kuramoto_sivashinsky: forward only arguments its signature accepts, so
    # noise options such as ``prob`` / ``freq`` in kwargs do not raise TypeError.
    ks_kwargs = {key: kwargs[key] for key in ('warmup', 'N', 'M') if key in kwargs}
    if x0 is not None:
        ks_kwargs['x0'] = np.asarray(x0, dtype=float)
    elif not _is_random_init(init):
        ks_kwargs['x0'] = np.asarray(init, dtype=float)
    # init='random' without x0: reservoirpy's default initial condition is used.
    return datasets.kuramoto_sivashinsky(n_timesteps=n_timesteps, h=h, **ks_kwargs)


def _make_noisy_copy(clean_data, noise, intensity, kwargs):
    """Return ``clean_data`` with noise added according to the ``noise`` spec."""
    if noise is None:
        return clean_data.copy()
    if isinstance(noise, str):
        return add_noise(clean_data, noise_type=noise, intensity=intensity, **kwargs)
    if isinstance(noise, (list, tuple)):
        if not noise:
            # Averaging over zero noise types would divide by zero (NaN data).
            return clean_data.copy()
        noise_sum = np.zeros(clean_data.shape, dtype=float)
        for nt in noise:
            noise_sum += add_noise(np.zeros_like(clean_data), noise_type=nt,
                                   intensity=intensity, **kwargs)
        return clean_data + noise_sum / len(noise)
    raise ValueError("噪声参数 noise 必须为字符串或字符串列表。")
|
|
|
|
|
|
|
2025-03-27 21:53:48 +08:00
|
|
|
|
def get_lorenz(n_timesteps=10000, h=0.01, transient=1000,
               noise='gaussian', intensity=0.1, x0=[1, 1, 1],
               rho=28, sigma=10, beta=8/3, normlization=True):
    """Generate Lorenz-system data and a noisy copy.

    Thin wrapper around ``load_data(system='lorenz', ...)``.

    Parameters
    ----------
    n_timesteps : int
        Number of simulated time steps.
    h : float
        Integration time step.
    transient : int
        Number of leading time steps to discard.
    noise : str, list of str, or None
        Noise type(s) forwarded to ``load_data``.
    intensity : float
        Noise intensity.
    x0 : list
        Initial state, forwarded to ``load_data`` as the ``x0`` keyword.
    rho, sigma, beta : float
        Lorenz system parameters.
    normlization : bool
        Whether ``load_data`` min-max normalizes the data.

    Returns
    -------
    (clean_data, noisy_data) : tuple of numpy.ndarray
        Clean data and the noise-corrupted copy.
    """
    # Forward `normlization` (it was previously hard-coded to True, so the
    # parameter had no effect).
    return load_data(system='lorenz', noise=noise, intensity=intensity,
                     n_timesteps=n_timesteps, transient=transient, h=h,
                     x0=x0, rho=rho, sigma=sigma, beta=beta,
                     normlization=normlization)
|
|
|
|
|
|
|
|
|
|
|
|
def get_kuramoto_sivashinsky(n_timesteps=10000, h=0.25, transient=2000,
                             noise='gaussian', intensity=0.1, N=128, M=16,
                             x0=None, normlization=True):
    """Generate Kuramoto-Sivashinsky data and a noisy copy.

    Thin wrapper around ``load_data(system='kuramoto_sivashinsky', ...)``.

    Parameters
    ----------
    n_timesteps : int
        Number of simulated time steps.
    h : float
        Integration time step.
    transient : int
        Number of leading time steps to discard.
    noise : str, list of str, or None
        Noise type(s) forwarded to ``load_data``.
    intensity : float
        Noise intensity.
    N : int
        System parameter forwarded to reservoirpy's ``kuramoto_sivashinsky``.
    M : float
        System parameter forwarded to reservoirpy's ``kuramoto_sivashinsky``.
    x0 : list, ndarray or None
        Initial state, forwarded to ``load_data`` as the ``x0`` keyword.
    normlization : bool
        Whether ``load_data`` min-max normalizes the data.

    Returns
    -------
    (clean_data, noisy_data) : tuple of numpy.ndarray
        Clean data and the noise-corrupted copy.
    """
    # Forward `normlization` (it was previously hard-coded to True).
    return load_data(system='kuramoto_sivashinsky', noise=noise, intensity=intensity,
                     n_timesteps=n_timesteps, transient=transient, h=h,
                     N=N, M=M, x0=x0, normlization=normlization)
|
|
|
|
|
|
|
|
|
|
|
|
def get_wind(noise='gaussian', intensity=0.1, filepath='exp/data/scale_windspeed.txt', normlization=True):
    """Load wind-speed data from a text file and return it with a noisy copy.

    Parameters
    ----------
    noise : str or None
        Noise type passed to ``add_noise``. ``None`` returns an unmodified copy
        as the "noisy" data.
    intensity : float
        Noise intensity.
    filepath : str
        Text file readable by ``np.loadtxt``. The loaded array is transposed,
        and axis 0 of the result is then treated as the sample axis.
    normlization : bool
        If True, min-max scale each column of the transposed data to [0, 1].

    Returns
    -------
    (wind_data, wind_noisy) : tuple of numpy.ndarray
        The (optionally normalized) data and its noisy copy.
    """
    # NOTE(review): an earlier comment stated "shape -> [155, 157819]";
    # confirm the orientation against the actual data file.
    wind_data = np.loadtxt(filepath).T

    if normlization:
        col_min = wind_data.min(axis=0)
        col_range = wind_data.max(axis=0) - col_min
        wind_data = (wind_data - col_min) / (col_range + 1e-8)

    if noise is None:
        # Without this branch `wind_noisy` was never bound and the return
        # raised UnboundLocalError; mirror load_data's behavior instead.
        wind_noisy = wind_data.copy()
    else:
        wind_noisy = add_noise(wind_data, noise_type=noise, intensity=intensity)

    return wind_data, wind_noisy
|
|
|
|
|
|
|
2025-03-11 21:49:50 +08:00
|
|
|
|
def visualize_data_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix=""):
    """Plot noisy-vs-clean and prediction-vs-clean time series per dimension.

    Parameters
    ----------
    clean_data, noisy_data, prediction : numpy.ndarray
        Arrays of identical shape, either 1-D ``(T,)`` or 2-D ``(T, D)``.
    warmup : int
        Number of leading time steps skipped in every panel.
    title_prefix : str
        Prefix for each subplot title.

    With at most 3 dimensions, each dimension gets its own row of two panels.
    With more, two dimensions share a row of four panels and the legend is
    drawn only for the first dimension. The figure is created but not shown.
    """
    n_dims = 1 if clean_data.ndim == 1 else clean_data.shape[1]

    # Each dimension needs two panels: noisy-vs-clean and prediction-vs-clean.
    if n_dims <= 3:
        n_cols, fig_width = 2, 12
    else:
        n_cols, fig_width = 4, 14
    dims_per_row = n_cols // 2
    n_rows = -(-n_dims // dims_per_row)  # ceiling division
    # The subplot grid has exactly n_rows rows, matching the figure height;
    # a grid of n_dims rows would leave half of every 4-column row empty.
    fig_height = 4 * n_dims if n_dims <= 3 else 3 * n_rows

    plt.figure(figsize=(fig_width, fig_height))

    for i in range(n_dims):
        if clean_data.ndim == 1:
            clean_dim = clean_data[warmup:]
            noisy_dim = noisy_data[warmup:]
            pred_dim = prediction[warmup:]
        else:
            clean_dim = clean_data[warmup:, i]
            noisy_dim = noisy_data[warmup:, i]
            pred_dim = prediction[warmup:, i]

        # 1-based index of this dimension's first panel in the grid.
        first_panel = (i // dims_per_row) * n_cols + (i % dims_per_row) * 2 + 1
        show_legend = i == 0 or n_dims <= 3

        # Noisy vs clean
        plt.subplot(n_rows, n_cols, first_panel)
        plt.plot(noisy_dim, label='noisy')
        plt.plot(clean_dim, label='clean')
        plt.title(f"{title_prefix} Dim {i+1} (Noisy vs Clean)")
        if show_legend:
            plt.legend()

        # Prediction vs clean
        plt.subplot(n_rows, n_cols, first_panel + 1)
        plt.plot(pred_dim, label='prediction')
        plt.plot(clean_dim, label='clean')
        plt.title(f"{title_prefix} Dim {i+1} (Prediction vs Clean)")
        if show_legend:
            plt.legend()

    plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_psd_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix="",
                       max_dims=5, nfft=1024):
    """Plot power spectral densities of clean, noisy and predicted data.

    One subplot per plotted dimension, each overlaying the three PSDs
    computed with ``plt.psd``.

    Parameters
    ----------
    clean_data, noisy_data, prediction : numpy.ndarray
        Arrays of identical shape, 1-D ``(T,)`` or 2-D ``(T, D)``.
    warmup : int
        Number of leading time steps excluded from the PSD estimate.
    title_prefix : str
        Prefix for each subplot title.
    max_dims : int or None
        Plot only the first ``max_dims`` dimensions (default 5, the previous
        hard-coded limit). ``None`` plots every dimension.
    nfft : int
        Segment length passed to ``plt.psd`` as ``NFFT`` (default 1024).

    The figure is created but not shown.
    """
    n_dims = 1 if clean_data.ndim == 1 else clean_data.shape[1]
    show_dims = n_dims if max_dims is None else min(n_dims, max_dims)

    # 4 inches of height per plotted dimension.
    plt.figure(figsize=(12, 4 * show_dims))

    for i in range(show_dims):
        plt.subplot(show_dims, 1, i + 1)

        if clean_data.ndim == 1:
            series = (clean_data[warmup:], noisy_data[warmup:], prediction[warmup:])
        else:
            series = (clean_data[warmup:, i], noisy_data[warmup:, i], prediction[warmup:, i])

        for values, label in zip(series, ('clean', 'noisy', 'prediction')):
            plt.psd(values, NFFT=nfft, label=label)

        plt.title(f"{title_prefix} Dim {i+1} PSD Comparison")
        plt.legend()

    plt.tight_layout()
|
|
|
|
|
|
|
2025-03-27 21:53:48 +08:00
|
|
|
|
def _ks_time_points(time_slice, num_steps):
    """Pick up to three valid time indices for the spatial-profile row.

    With ``time_slice=None``, three evenly spaced interior indices are chosen
    and padded with the first/last/middle index when the series is short.
    Otherwise the valid entries of ``time_slice`` are kept, at most three.
    """
    if time_slice is None:
        if num_steps <= 0:
            return []
        # Interior points of a 5-point grid over [0, num_steps - 1].
        points = sorted(set(np.linspace(0, num_steps - 1, 5, dtype=int)[1:-1]))
        # Short series can collapse to duplicate indices; pad back up to 3.
        while len(points) < 3 and len(points) < num_steps:
            if 0 not in points:
                points.insert(0, 0)
            elif num_steps - 1 not in points:
                points.append(num_steps - 1)
            elif num_steps // 2 not in points:
                points.append(num_steps // 2)
            else:
                break  # no further unique index is available
        return sorted(set(points))[:3]

    points = [t for t in time_slice if 0 <= t < num_steps]
    if len(points) > 3:
        print(f"Warning: Provided {len(time_slice)} time slices. Using the first 3 valid ones: {points[:3]}")
        points = points[:3]
    elif not points and len(time_slice) > 0:  # len() also works for ndarray input
        print(f"Warning: None of the provided time slices {time_slice} are valid for data length {num_steps}.")
    return points


def visualize_ks_denoising(clean_data, noisy_data, prediction, warmup=0, title_prefix="", time_slice=None):
    """Visualize Kuramoto-Sivashinsky denoising results on a 4x3 grid.

    Rows:
      1. clean / noisy / (clean - noisy) space-time heat maps
      2. clean / prediction / (clean - prediction) heat maps
      3. spatial profiles at up to three time steps
      4. error histograms (noisy, prediction) and per-time-step MSE

    Parameters
    ----------
    clean_data, noisy_data, prediction : numpy.ndarray
        Arrays of shape ``(n_timesteps, n_space)``.
    warmup : int
        Number of leading time steps removed before plotting.
    title_prefix : str
        Prefix for the figure's super-title.
    time_slice : sequence of int or None
        Time indices (relative to the post-warmup data) for row 3. At most the
        first three valid ones are used; ``None`` picks them automatically.

    Prints a warning and returns early if no data remains after ``warmup``.
    The figure is created but not shown.
    """
    # 1. Data preparation
    if warmup >= clean_data.shape[0]:
        print(f"Warning: Warmup period ({warmup}) is longer than or equal to data length ({clean_data.shape[0]}). No data to plot.")
        return

    clean = clean_data[warmup:]
    noisy = noisy_data[warmup:]
    pred = prediction[warmup:]

    if clean.shape[0] == 0:
        print("Warning: Data is empty after removing warmup period.")
        return

    diff_noisy = clean - noisy
    diff_pred = clean - pred
    mse_noisy = np.mean(diff_noisy ** 2)
    mse_pred = np.mean(diff_pred ** 2)

    # Symmetric color range shared by both difference maps, floored so vmin != vmax.
    max_abs_diff = max(np.abs(diff_noisy).max(), np.abs(diff_pred).max(), 1e-9)

    # Shared color range for the data maps so they are directly comparable.
    data_vmin = min(clean.min(), noisy.min(), pred.min())
    data_vmax = max(clean.max(), noisy.max(), pred.max())
    if data_vmax - data_vmin < 1e-9:  # constant data
        data_vmax = data_vmin + 1e-9

    num_steps, n_space = clean.shape
    time_label = 'Time Step (after warmup)'
    space_label = 'Space'
    amplitude_label = 'Amplitude'
    difference_label = 'Difference'

    # 2. Figure with a 4x3 grid of axes
    fig, axes = plt.subplots(4, 3, figsize=(15, 17))
    fig.suptitle(f"{title_prefix} Kuramoto-Sivashinsky Denoising Results", fontsize=16)

    # extent = [time_min, time_max, space_min, space_max]. origin='lower' is
    # needed for this extent to be truthful: with imshow's default
    # origin='upper', element [0, 0] is drawn at (left, top), so space index 0
    # would be labeled n_space.
    extent = [0, num_steps, 0, n_space]

    def heatmap(ax, values, title, cmap, vmin, vmax, cbar_label):
        # values is (time, space); transpose so time runs along x, space along y.
        im = ax.imshow(values.T, aspect='auto', cmap=cmap, origin='lower',
                       extent=extent, vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax, label=cbar_label)
        ax.set_title(title)

    # Row 1: clean, noisy, clean - noisy
    heatmap(axes[0, 0], clean, "Clean Data", 'viridis', data_vmin, data_vmax, amplitude_label)
    heatmap(axes[0, 1], noisy, f"Noisy Data (MSE: {mse_noisy:.4f})", 'viridis',
            data_vmin, data_vmax, amplitude_label)
    heatmap(axes[0, 2], diff_noisy, "Clean - Noisy", 'coolwarm',
            -max_abs_diff, max_abs_diff, difference_label)
    axes[0, 0].set_ylabel(space_label)  # y label only on the leftmost column

    # Row 2: clean, prediction, clean - prediction
    heatmap(axes[1, 0], clean, "Clean Data", 'viridis', data_vmin, data_vmax, amplitude_label)
    heatmap(axes[1, 1], pred, f"Prediction (MSE: {mse_pred:.4f})", 'viridis',
            data_vmin, data_vmax, amplitude_label)
    heatmap(axes[1, 2], diff_pred, "Clean - Prediction", 'coolwarm',
            -max_abs_diff, max_abs_diff, difference_label)
    axes[1, 0].set_ylabel(space_label)
    for ax in axes[1]:
        ax.set_xlabel(time_label)  # time label on the lower heat-map row only

    # Row 3: spatial profiles at selected time steps
    t_points = _ks_time_points(time_slice, num_steps)
    for i, ax_slice in enumerate(axes[2]):
        if i >= len(t_points):
            ax_slice.axis('off')  # hide slots without a time point
            continue
        t = t_points[i]
        ax_slice.plot(clean[t], label='Clean')
        ax_slice.plot(noisy[t], label='Noisy', alpha=0.7)
        ax_slice.plot(pred[t], label='Prediction', linestyle='--')
        ax_slice.set_title(f"Time slice at t={t}")
        ax_slice.set_xlabel(space_label)
        if i == 0:
            ax_slice.set_ylabel(amplitude_label)
        ax_slice.legend()
        ax_slice.grid(True, linestyle='--', alpha=0.6)

    # Row 4: error histograms and MSE over time
    ax_hist_noisy, ax_hist_pred, ax_mse = axes[3]
    bins = 50

    counts_noisy, _, _ = ax_hist_noisy.hist(diff_noisy.flatten(), bins=bins, alpha=0.7, label='Noisy Error')
    ax_hist_noisy.set_title('Noise Error Distribution')
    ax_hist_noisy.set_xlabel('Error (Clean - Noisy)')
    ax_hist_noisy.set_ylabel('Count')
    ax_hist_noisy.grid(True, linestyle='--', alpha=0.6)

    counts_pred, _, _ = ax_hist_pred.hist(diff_pred.flatten(), bins=bins, alpha=0.7,
                                          label='Prediction Error', color='orange')
    ax_hist_pred.set_title('Prediction Error Distribution')
    ax_hist_pred.set_xlabel('Error (Clean - Prediction)')

    # Shared y-limits so the two histograms can be compared by height.
    max_count = max((c.max() for c in (counts_noisy, counts_pred) if c.size > 0), default=0)
    common_ylim = (0, max_count * 1.1) if max_count > 0 else (0, 1)
    ax_hist_noisy.set_ylim(common_ylim)
    ax_hist_pred.set_ylim(common_ylim)

    time_axis = np.arange(num_steps)
    mse_noisy_t = np.mean(diff_noisy ** 2, axis=1)
    mse_pred_t = np.mean(diff_pred ** 2, axis=1)
    ax_mse.plot(time_axis, mse_noisy_t, label='Noisy MSE')
    ax_mse.plot(time_axis, mse_pred_t, label='Prediction MSE')
    ax_mse.set_title('MSE over Time')
    ax_mse.set_xlabel(time_label)
    ax_mse.set_ylabel('MSE')
    # A log axis needs positive values; stay linear if every MSE is zero.
    if np.any(mse_noisy_t > 0) or np.any(mse_pred_t > 0):
        ax_mse.set_yscale('log')
    ax_mse.legend()
    ax_mse.grid(True, linestyle='--', alpha=0.6)

    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.01, 1, 0.96])
|
|
|
|
|
|
|
2025-03-08 20:25:01 +08:00
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load noisy Kuramoto-Sivashinsky data and visualize it, with
    # the noisy series standing in for a model prediction.
    clean_data, noisy_data = get_kuramoto_sivashinsky()
    print("Clean Data Shape:", clean_data.shape)
    print("Noisy Data Shape:", noisy_data.shape)
    visualize_ks_denoising(clean_data, noisy_data, noisy_data)
    # visualize_ks_denoising only builds the figure; display it when run as a
    # script (otherwise the process exits without showing anything).
    plt.show()
|