denoise_RC/exp/exp1_basic.py
2025-03-22 14:02:38 +08:00

105 lines
4.9 KiB
Python

"""
实验1: 基础降噪与动力学特性恢复实验 (核心实验)
目标:
验证 RC 降噪方法在降噪和动力学特性恢复方面的基本有效性,建立基准性能。
实验设置:
- 混沌系统: 例如 Lorenz 或 Rossler 系统。
- 噪声类型: 高斯白噪声 (可扩展到其他噪声)。
- 噪声水平: 不同 SNR 设置。
- RC 模型: 使用 ESN 等 RC 模型,调优储备池大小、谱半径等参数。
- 训练方法: 监督学习 (例如 Ridge 回归)。
评估指标:
- 降噪指标: SNR 提升、均方根误差 (RMSE)、相关系数、峰值信噪比。
- 动力学特性恢复指标: 吸引子重构质量、Lyapunov 指数估计误差、分形维数估计误差、预测精度。
结果展示:
- 可视化时间序列、吸引子重构、PSD 对比图和表格展示指标结果。
"""
import numpy as np
import matplotlib.pyplot as plt
import reservoirpy as rpy
import reservoirpy.nodes as rpn
# Global reservoirpy configuration: turn library verbosity off and fix the global
# random seed (42) so reservoir weight initialisation is repeatable across runs.
rpy.verbosity(0)
rpy.set_seed(42)
from tools import load_data, visualize_data_diff, visualize_psd_diff
def build_esn_model(units, lr, sr, reg, output_dim, reservoir_node=None):
    """Assemble a reservoir -> ridge-readout ESN pipeline.

    Args:
        units: number of reservoir neurons (only used when a new reservoir is built).
        lr: reservoir leak rate (only used when a new reservoir is built).
        sr: reservoir spectral radius (only used when a new reservoir is built).
        reg: ridge regularisation coefficient of the linear readout.
        output_dim: dimensionality of the readout output.
        reservoir_node: an existing reservoir node to reuse. When given, ``units``,
            ``lr`` and ``sr`` are ignored and only a fresh readout is created.

    Returns:
        ``(model, reservoir)`` - the chained model and the reservoir node it uses,
        so the caller can share that reservoir with later models.
    """
    reservoir = (
        reservoir_node
        if reservoir_node is not None
        else rpn.Reservoir(units=units, lr=lr, sr=sr)
    )
    readout = rpn.Ridge(output_dim=output_dim, ridge=reg)
    return reservoir >> readout, reservoir
def train_model(model, clean_data, noisy_data, warmup=1000):
    """Fit *model* to map noisy inputs onto clean targets, then run it on the inputs.

    Args:
        model: object exposing ``fit(x, y, warmup=...)`` and ``run(x)``.
        clean_data: training targets.
        noisy_data: training inputs.
        warmup: number of leading steps passed to ``fit`` as warmup.

    Returns:
        ``(model, prediction)`` where ``prediction`` is ``model.run(noisy_data)``
        evaluated after fitting.
    """
    inputs, targets = noisy_data, clean_data
    model.fit(inputs, targets, warmup=warmup)
    fitted_output = model.run(inputs)
    return model, fitted_output
def evaluate_results(model, clean_data, noisy_data, warmup=0, vis_data=True, vis_psd=False):
    """Run *model* on *noisy_data* and score its output against *clean_data*.

    Args:
        model: object exposing ``run(x)`` (e.g. a trained ESN).
        clean_data: reference signal; time is axis 0 (1-D, or 2-D with one
            column per channel).
        noisy_data: model input, same length as ``clean_data``.
        warmup: number of leading time steps excluded from the metrics; also
            forwarded to the plotting helpers. Default 0 scores every step.
        vis_data: if True, plot the time-series comparison.
        vis_psd: if True, plot the PSD comparison.

    Returns:
        ``(snr, rmse, psnr, corr_coef, prediction)``:
        snr - 10*log10(mean(clean**2) / mean(residual**2)) in dB;
        rmse - root-mean-square error;
        psnr - 20*log10(max(clean) / rmse) in dB;
        corr_coef - Pearson correlation computed per channel over time, then averaged;
        prediction - the FULL model output (not warmup-trimmed), so it can be
        fed to the next denoising stage.
    """
    prediction = model.run(noisy_data)

    # Score only the post-warmup part; the returned prediction stays full-length.
    clean = np.asarray(clean_data)[warmup:]
    pred = np.asarray(prediction)[warmup:]
    residual = clean - pred

    mse = np.mean(residual ** 2)
    snr = 10 * np.log10(np.mean(clean ** 2) / mse)
    rmse = np.sqrt(mse)
    psnr = 20 * np.log10(np.max(clean) / rmse)

    # np.corrcoef treats each ROW of a 2-D input as a variable, so passing (T, D)
    # arrays would correlate time steps with each other (building a 2T x 2T
    # matrix) rather than signals. Correlate each channel over time instead.
    clean_2d = clean.reshape(clean.shape[0], -1)
    pred_2d = pred.reshape(pred.shape[0], -1)
    corr_coef = np.mean([
        np.corrcoef(clean_2d[:, k], pred_2d[:, k])[0, 1]
        for k in range(clean_2d.shape[1])
    ])

    # Per-channel comparisons: noisy vs clean | prediction vs clean
    if vis_data:
        visualize_data_diff(clean_data, noisy_data, prediction, warmup=warmup)
    # PSD comparison
    if vis_psd:
        visualize_psd_diff(clean_data, noisy_data, prediction, warmup=warmup)
    return snr, rmse, psnr, corr_coef, prediction
def run():
    """Experiment 1 entry point: a five-level cascade of ESN denoisers.

    Each level is trained to map its input onto the clean signal: the raw noisy
    signal for LV1, the previous level's output afterwards. Each level is then
    evaluated on the held-out last 20 % of the series. Metrics for every level
    are printed, and any figures produced by the plotting helpers are shown at
    the end.
    """
    # Data loading (see tools.load_data for the exact meaning of each argument)
    clean_data, noisy_data = load_data(system='lorenz', noise='gaussian',
                                       intensity=0.5, init=[1, 1, 1],
                                       n_timesteps=40000, transient=10000, h=0.01)

    # Chronological 80/20 train/test split
    split = int(len(clean_data) * 0.8)
    train_clean, test_clean = clean_data[:split], clean_data[split:]
    train_input, test_input = noisy_data[:split], noisy_data[split:]

    # (leak rate, ridge) for LV1..LV5.
    # NOTE(review): only LV1 builds a reservoir; LV2-LV5 reuse it, so their lr
    # values are ignored by build_esn_model and only `reg` differs per level.
    # Confirm whether a shared reservoir or per-level reservoirs was intended.
    stages = ((0.1, 1e-5), (0.2, 1e-5), (0.3, 1e-5), (0.4, 1e-4), (0.5, 1e-3))

    shared_reservoir = None
    for leak_rate, ridge in stages:
        model, shared_reservoir = build_esn_model(units=1000, lr=leak_rate, sr=0.9, reg=ridge,
                                                  output_dim=3, reservoir_node=shared_reservoir)
        model, train_output = train_model(model, train_clean, train_input)
        snr, rmse, psnr, corr_coef, test_output = evaluate_results(model, test_clean, test_input)
        print(f"SNR: {snr:.6f} RMSE: {rmse:.6f} PSNR: {psnr:.6f} CorrCoef: {corr_coef:.6f}")
        # The next level denoises this level's output (train and test alike)
        train_input, test_input = train_output, test_output

    plt.show()
if __name__ == "__main__":
    # Script entry point: execute the whole experiment
    run()