change exp design
This commit is contained in:
parent
48acf57452
commit
5cf5e5e2ab
1371
.ipynb_checkpoints/denoise many try-checkpoint.ipynb
Normal file
1371
.ipynb_checkpoints/denoise many try-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
532
.ipynb_checkpoints/denoise one test-checkpoint.ipynb
Normal file
532
.ipynb_checkpoints/denoise one test-checkpoint.ipynb
Normal file
@ -0,0 +1,532 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import reservoirpy as rpy\n",
|
||||||
|
"import reservoirpy.nodes as rpn\n",
|
||||||
|
"from reservoirpy.datasets import lorenz, rossler, doublescroll, kuramoto_sivashinsky\n",
|
||||||
|
"\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from matplotlib import pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"rpy.verbosity(False)\n",
|
||||||
|
"\n",
|
||||||
|
"rpy.set_seed(42)\n",
|
||||||
|
"\n",
|
||||||
|
"name_idx = 0"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def add_noise(time_series, noise_type=\"random\", noise_level=0.5):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" 对输入的时间序列添加噪声。\n",
|
||||||
|
" \n",
|
||||||
|
" 参数:\n",
|
||||||
|
" time_series (numpy.ndarray): 输入的时间序列。\n",
|
||||||
|
" noise_type (str): 噪声类型,可选值为 \"random\"(随机噪声), \"sin\"(正弦噪声), 或 \"gaussian\"(正态分布噪声)。\n",
|
||||||
|
" noise_level (float): 噪声强度,决定噪声幅度。\n",
|
||||||
|
" \n",
|
||||||
|
" 返回:\n",
|
||||||
|
" numpy.ndarray: 添加噪声后的时间序列。\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" if noise_type == \"random\":\n",
|
||||||
|
" noise = np.random.uniform(-noise_level, noise_level, len(time_series))\n",
|
||||||
|
" elif noise_type == \"sin\":\n",
|
||||||
|
" noise = noise_level * np.sin(30 * np.pi * np.arange(len(time_series)) / len(time_series))\n",
|
||||||
|
" elif noise_type == \"gaussian\":\n",
|
||||||
|
" noise = np.random.normal(0, noise_level, len(time_series))\n",
|
||||||
|
" else:\n",
|
||||||
|
" raise ValueError(\"Unsupported noise_type. Choose from 'random', 'sin', or 'gaussian'.\")\n",
|
||||||
|
" \n",
|
||||||
|
" return time_series + np.stack((noise,noise,noise)).T"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def show(t1, t2, figsize=(18, 6), title=\"Lorenz Attractor\"):\n",
|
||||||
|
" '''\n",
|
||||||
|
" t1: Noisy data(plot with solid line & orange color)\n",
|
||||||
|
" \n",
|
||||||
|
" t2: Clean data(plot with dashed line & blue color)\n",
|
||||||
|
" '''\n",
|
||||||
|
" plt.figure(figsize=figsize)\n",
|
||||||
|
"\n",
|
||||||
|
" plt.subplot(3, 1, 1)\n",
|
||||||
|
" #plt.plot(t2[:,0])\n",
|
||||||
|
" plt.plot(t1[:,0], color='#ff7f0e')\n",
|
||||||
|
" plt.plot(t2[:,0], linestyle='--')\n",
|
||||||
|
" plt.title('X-component')\n",
|
||||||
|
"\n",
|
||||||
|
" plt.subplot(3, 1, 2)\n",
|
||||||
|
" #plt.plot(t2[:,1])\n",
|
||||||
|
" plt.plot(t1[:,1], color='#ff7f0e')\n",
|
||||||
|
" plt.plot(t2[:,1], linestyle='--')\n",
|
||||||
|
" plt.title('Y-component')\n",
|
||||||
|
"\n",
|
||||||
|
" plt.subplot(3, 1, 3)\n",
|
||||||
|
" #plt.plot(t2[:,2])\n",
|
||||||
|
" plt.plot(t1[:,2], color='#ff7f0e')\n",
|
||||||
|
" plt.plot(t2[:,2], linestyle='--')\n",
|
||||||
|
" plt.title('Z-component')\n",
|
||||||
|
" \n",
|
||||||
|
" fig = plt.figure(figsize=(12,5), dpi=150)\n",
|
||||||
|
" \n",
|
||||||
|
" rmse = np.sqrt(np.mean((t1 - t2)**2))\n",
|
||||||
|
" fig.suptitle(f\"{title}\\nRMSE: {rmse:.8f}\")\n",
|
||||||
|
" \n",
|
||||||
|
" ax = fig.add_subplot(121, projection='3d')\n",
|
||||||
|
" ax.plot(t2[:,0], t2[:,1], t2[:,2])\n",
|
||||||
|
" ax.set_title(\"Raw Data\")\n",
|
||||||
|
"\n",
|
||||||
|
" ax = fig.add_subplot(122, projection='3d')\n",
|
||||||
|
" ax.plot(t1[:,0], t1[:,1], t1[:,2], color='#ff7f0e')\n",
|
||||||
|
" ax.set_title(\"Processed Data\")\n",
|
||||||
|
" \n",
|
||||||
|
" plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Data acquisition"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"norm01 = lambda x: (x - np.min(x, axis=0)) / (np.max(x, axis=0) - np.min(x, axis=0))\n",
|
||||||
|
"\n",
|
||||||
|
"lorenz_data = norm01(lorenz(12000, h=0.01)[2000:,:])\n",
|
||||||
|
"rossler_data = norm01(rossler(12000, h=0.02)[2000:,:])\n",
|
||||||
|
"doublescroll_data = norm01(doublescroll(12000, h=0.1)[2000:,:])\n",
|
||||||
|
"\n",
|
||||||
|
"noisy_lorenz_data = add_noise(lorenz_data, noise_type=\"gaussian\", noise_level=0.8)\n",
|
||||||
|
"noisy_rossler_data = add_noise(rossler_data, noise_type=\"gaussian\", noise_level=0.8)\n",
|
||||||
|
"noisy_doublescroll_data = add_noise(doublescroll_data, noise_type=\"gaussian\", noise_level=0.8)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(noisy_lorenz_data, lorenz_data, title=\"Lorenz Raw Data & Noisy Data\")\n",
|
||||||
|
"show(noisy_rossler_data, rossler_data, title=\"Rossler Raw Data & Noisy Data\")\n",
|
||||||
|
"show(noisy_doublescroll_data, doublescroll_data, title=\"Doublescroll Raw Data & Noisy Data\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Lorenz Test Case"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"lorenz_train_input = noisy_lorenz_data[:7000]\n",
|
||||||
|
"lorenz_train_target = lorenz_data[:7000]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"lorenz_res = rpn.Reservoir(units=1000, lr=0.5, sr=0.9, activation='tanh', equation='external')\n",
|
||||||
|
"lorenz_readout = rpn.Ridge(ridge=5e-3)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# train readout\n",
|
||||||
|
"states = lorenz_res.run(lorenz_train_input)\n",
|
||||||
|
"lorenz_readout.fit(states, lorenz_train_target)\n",
|
||||||
|
"output = lorenz_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Show the Training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(f'RMSE of output: {np.sqrt(np.mean((output - lorenz_train_target)**2))}')\n",
|
||||||
|
"show(output, lorenz_train_target, title=\"Lorenz RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"lorenz_test_input = noisy_lorenz_data[7000:]\n",
|
||||||
|
"lorenz_test_target = lorenz_data[7000:]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# test\n",
|
||||||
|
"states = lorenz_res.run(lorenz_test_input)\n",
|
||||||
|
"output = lorenz_readout1.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Show the Prediction"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(f'RMSE of output: {np.sqrt(np.mean((output - lorenz_test_target)**2))}')\n",
|
||||||
|
"show(output1, lorenz_test_target, title=\"Lorenz RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### 迁移学习"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_lorenz_data = norm01(lorenz(12000, h=0.01, x0=[0.1, 0.2, 0.3])[2000:,:])\n",
|
||||||
|
"noisy_new_lorenz_data = add_noise(new_lorenz_data, noise_type=\"gaussian\", noise_level=0.8)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_lorenz_test_input = noisy_new_lorenz_data\n",
|
||||||
|
"new_lorenz_test_target = new_lorenz_data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = lorenz_res.run(new_lorenz_test_input)\n",
|
||||||
|
"output = lorenz_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(output, new_lorenz_test_target, title=\"Lorenz RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Rossler Test Case"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"rossler_train_input = noisy_rossler_data[:7000]\n",
|
||||||
|
"rossler_train_target = rossler_data[:7000]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"rossler_input = rpn.Input()\n",
|
||||||
|
"rossler_res = rpn.Reservoir(units=1000, lr=0.5, sr=0.9, activation='tanh', equation='external')\n",
|
||||||
|
"rossler_readout = rpn.Ridge(ridge=5e-3)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = rossler_res.run(rossler_train_input)\n",
|
||||||
|
"rossler_readou1.fit(states, rossler_train_target)\n",
|
||||||
|
"output = rossler_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(rossler_train_input, rossler_train_target, title=\"Rossler Raw Data & Noisy Data\")\n",
|
||||||
|
"show(output, rossler_train_target, title=\"Rossler RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"rossler_test_input = noisy_rossler_data[7000:]\n",
|
||||||
|
"rossler_test_target = rossler_data[7000:]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = rossler_res.run(rossler_test_input)\n",
|
||||||
|
"output = rossler_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(output, rossler_test_target, title=\"Rossler RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_rossler_data = norm01(rossler(12000, h=0.02, x0=[0.1, 0.2, 0.3])[2000:,:])\n",
|
||||||
|
"noisy_new_rossler_data = add_noise(new_rossler_data, noise_type=\"gaussian\", noise_level=0.8)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_rossler_test_input = noisy_new_rossler_data\n",
|
||||||
|
"new_rossler_test_target = new_rossler_data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = rossler_res.run(new_rossler_test_input)\n",
|
||||||
|
"output = rossler_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(output, new_rossler_test_target, title=\"Rossler RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Double-scroll test case"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doublescroll_train_input = noisy_doublescroll_data[:7000]\n",
|
||||||
|
"doublescroll_train_target = doublescroll_data[:7000]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doublescroll_input = rpn.Input()\n",
|
||||||
|
"doublescroll_res = rpn.Reservoir(units=1000, lr=0.5, sr=0.9, activation='tanh', equation='external')\n",
|
||||||
|
"doublescroll_readout = rpn.Ridge(ridge=5e-3)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = doublescroll_res.run(doublescroll_train_input)\n",
|
||||||
|
"doublescroll_readout1.fit(states, doublescroll_train_target)\n",
|
||||||
|
"output = doublescroll_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(doublescroll_train_input, doublescroll_train_target, title=\"Doublescroll Raw Data & Noisy Data\")\n",
|
||||||
|
"show(output, doublescroll_train_target, title=\"Doublescroll RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doublescroll_test_input = noisy_doublescroll_data[7000:]\n",
|
||||||
|
"doublescroll_test_target = doublescroll_data[7000:]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = doublescroll_res.run(doublescroll_test_input)\n",
|
||||||
|
"output = doublescroll_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(output, doublescroll_test_target, title=\"Doublescroll RC Output\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_doublescroll_data = norm01(doublescroll(12000, h=0.1, x0=[0.1, 0.2, 0.3])[2000:,:])\n",
|
||||||
|
"noisy_new_doublescroll_data = add_noise(new_doublescroll_data, noise_type=\"gaussian\", noise_level=0.8)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"new_doublescroll_test_input = noisy_new_doublescroll_data\n",
|
||||||
|
"new_doublescroll_test_target = new_doublescroll_data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"states = doublescroll_res.run(new_doublescroll_test_input)\n",
|
||||||
|
"output = doublescroll_readout.run(states)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"show(output, new_doublescroll_test_target, title=\"Doublescroll RC Output\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
||||||
548
.ipynb_checkpoints/denoise try-checkpoint.ipynb
Normal file
548
.ipynb_checkpoints/denoise try-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
200
.ipynb_checkpoints/stand. noise show-checkpoint.ipynb
Normal file
200
.ipynb_checkpoints/stand. noise show-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
1
ARNN
Submodule
1
ARNN
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 0d0ad4398b4de6835035a09999b9bbd4fe6d8587
|
||||||
849
denoise one test.ipynb
Normal file
849
denoise one test.ipynb
Normal file
File diff suppressed because one or more lines are too long
3
exp/data/chaos gen.py
Normal file
3
exp/data/chaos gen.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from reservoirpy.datasets import lorenz
|
||||||
|
from reservoirpy.datasets import kuramoto_sivashinsky
|
||||||
|
|
||||||
155
exp/data/scale_windspeed.txt
Normal file
155
exp/data/scale_windspeed.txt
Normal file
File diff suppressed because one or more lines are too long
22
exp/data/try wind.py
Normal file
22
exp/data/try wind.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import numpy as np
|
||||||
|
'''
|
||||||
|
cr. ARNN - Support Information
|
||||||
|
6.2. Wind speed dataset The wind speed dataset, which is provided by the Japan Meteorological Business Support Center,
|
||||||
|
contains the wind speed (m/s) time series sampled every ∆𝑡 = 10 minutes between 2010 and 2012 from 𝐷 = 155 wind stations (variables) in Wakkanai, Japan9.
|
||||||
|
As for the 155 stations, their specific locations (latitude and longitude) can be found in the original dataset file 201606241049longitudelatitude.mat accessible in https://github.com/RPcb/ARNN/tree/master/Data/wind%20speed .
|
||||||
|
We use 𝑚 = 110 time points as the known series and make predictions on the next 𝐿 − 1 = 45 time points. As shown in Figs. 3a-3b of the main text, the performance of ARNN is better than the other methods.
|
||||||
|
Besides, utilizing this dataset, we tested the robustness of ARNN with different prediction steps in Figs. 3c-3e of the main text, which proved the effectiveness of ARNN in any time region.
|
||||||
|
'''
|
||||||
|
# shape -> [155, 157819]
|
||||||
|
data = np.loadtxt('exp/data/scale_windspeed.txt')
|
||||||
|
print(data.shape)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.figure(figsize=(12, 6))
|
||||||
|
plt.plot(data[0,:])
|
||||||
|
plt.title("Wind Speed Data")
|
||||||
|
plt.xlabel("Time (dt = 10 mins)")
|
||||||
|
plt.ylabel("Wind Speed (m/s)")
|
||||||
|
plt.grid(True)
|
||||||
|
plt.show()
|
||||||
346
exp/exp1.ipynb
Normal file
346
exp/exp1.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -23,7 +23,7 @@ import reservoirpy.nodes as rpn
|
|||||||
rpy.verbosity(0)
|
rpy.verbosity(0)
|
||||||
rpy.set_seed(42)
|
rpy.set_seed(42)
|
||||||
|
|
||||||
from tools import load_data, visualize_data_diff, visualize_psd_diff
|
from tools import visualize_data_diff, visualize_psd_diff, get_wind, get_kuramoto_sivashinsky, get_lorenz
|
||||||
|
|
||||||
def build_esn_model(units, lr, sr, reg, output_dim, reservoir_node=None):
|
def build_esn_model(units, lr, sr, reg, output_dim, reservoir_node=None):
|
||||||
|
|
||||||
@ -59,9 +59,8 @@ def evaluate_results(model, clean_data, noisy_data, warmup=0, vis_data=True, vis
|
|||||||
|
|
||||||
def run():
|
def run():
|
||||||
# 数据加载与预处理
|
# 数据加载与预处理
|
||||||
clean_data, noisy_data = load_data(system='lorenz', noise='gaussian',
|
clean_data, noisy_data = get_lorenz(noise='gaussian', intensity=0.1)
|
||||||
intensity=0.5, init=[1, 1, 1],
|
|
||||||
n_timesteps=40000, transient=10000, h=0.01)
|
|
||||||
# 分为训练集和测试集 按百分比
|
# 分为训练集和测试集 按百分比
|
||||||
train_size = int(len(clean_data) * 0.8)
|
train_size = int(len(clean_data) * 0.8)
|
||||||
train_clean_data = clean_data[:train_size]
|
train_clean_data = clean_data[:train_size]
|
||||||
|
|||||||
422
exp/tools.py
422
exp/tools.py
@ -80,7 +80,14 @@ def add_noise(data, noise_type='gaussian', intensity=0.1, **kwargs):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("未知的噪声类型。可选项:'gaussian', 'colored' (1/f), 'impulse', 'sine'")
|
raise ValueError("未知的噪声类型。可选项:'gaussian', 'colored' (1/f), 'impulse', 'sine'")
|
||||||
|
|
||||||
def load_data(system='lorenz', init='random', noise=None, intensity=0.1, h=0.01, n_timesteps=10000, transient=1000, normlization=True, **kwargs):
|
def load_data(system='lorenz',
|
||||||
|
init='random',
|
||||||
|
noise=None,
|
||||||
|
intensity=0.1,
|
||||||
|
h=0.01,
|
||||||
|
n_timesteps=10000,
|
||||||
|
transient=1000,
|
||||||
|
normlization=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
加载混沌系统数据.
|
加载混沌系统数据.
|
||||||
|
|
||||||
@ -100,17 +107,15 @@ def load_data(system='lorenz', init='random', noise=None, intensity=0.1, h=0.01,
|
|||||||
|
|
||||||
各系统默认值:
|
各系统默认值:
|
||||||
- lorenz: rho=28, sigma=10, beta=8/3, h=0.03, x0=[1, 1, 1]
|
- lorenz: rho=28, sigma=10, beta=8/3, h=0.03, x0=[1, 1, 1]
|
||||||
- rossler: a=0.2, b=0.2, c=5.7, h=0.1, x0=[-0.1, 0, 0.02]
|
|
||||||
- multiscroll: a=40, b=3, c=28, h=0.01, x0=[-0.1, 0.5, -0.6]
|
|
||||||
- kuramoto_sivashinsky: N=128, M=16, h=0.25, x0=None
|
- kuramoto_sivashinsky: N=128, M=16, h=0.25, x0=None
|
||||||
"""
|
"""
|
||||||
system = system.lower()
|
system = system.lower()
|
||||||
if init == 'random':
|
if init == 'random':
|
||||||
if system in ['lorenz', 'rossler', 'multiscroll']:
|
if system in ['lorenz']:
|
||||||
# 默认3维初始值
|
# 默认3维初始值
|
||||||
state = np.random.rand(3)
|
state = np.random.rand(3)
|
||||||
elif system in ['kuramoto_sivashinsky']:
|
elif system in ['kuramoto_sivashinsky']:
|
||||||
state = np.random.rand(1)
|
state = np.random.rand(kwargs.get('N', 128))
|
||||||
else:
|
else:
|
||||||
raise ValueError("未知的混沌系统类型。")
|
raise ValueError("未知的混沌系统类型。")
|
||||||
else:
|
else:
|
||||||
@ -133,40 +138,6 @@ def load_data(system='lorenz', init='random', noise=None, intensity=0.1, h=0.01,
|
|||||||
beta = kwargs.get('beta', 8/3)
|
beta = kwargs.get('beta', 8/3)
|
||||||
clean_data = datasets.lorenz(n_timesteps=n_timesteps, h=h, sigma=sigma, rho=rho, beta=beta, x0=state)
|
clean_data = datasets.lorenz(n_timesteps=n_timesteps, h=h, sigma=sigma, rho=rho, beta=beta, x0=state)
|
||||||
|
|
||||||
elif system == 'rossler':
|
|
||||||
'''
|
|
||||||
(function) def rossler(
|
|
||||||
n_timesteps: int,
|
|
||||||
a: float = 0.2,
|
|
||||||
b: float = 0.2,
|
|
||||||
c: float = 5.7,
|
|
||||||
x0: list | ndarray = [-0.1, 0, 0.02],
|
|
||||||
h: float = 0.1,
|
|
||||||
**kwargs: Any
|
|
||||||
) -> ndarray
|
|
||||||
'''
|
|
||||||
a = kwargs.get('a', 0.2)
|
|
||||||
b = kwargs.get('b', 0.2)
|
|
||||||
c = kwargs.get('c', 5.7)
|
|
||||||
clean_data = datasets.rossler(n_timesteps=n_timesteps, h=h, a=a, b=b, c=c, x0=state)
|
|
||||||
|
|
||||||
elif system == 'multiscroll':
|
|
||||||
'''
|
|
||||||
(function) def multiscroll(
|
|
||||||
n_timesteps: int,
|
|
||||||
a: float = 40,
|
|
||||||
b: float = 3,
|
|
||||||
c: float = 28,
|
|
||||||
x0: list | ndarray = [-0.1, 0.5, -0.6],
|
|
||||||
h: float = 0.01,
|
|
||||||
**kwargs: Any
|
|
||||||
) -> ndarray
|
|
||||||
'''
|
|
||||||
a = kwargs.get('a', 40)
|
|
||||||
b = kwargs.get('b', 3)
|
|
||||||
c = kwargs.get('c', 28)
|
|
||||||
clean_data = datasets.multiscroll(n_timesteps=n_timesteps, h=h, x0=state, a=a, b=b, c=c)
|
|
||||||
|
|
||||||
elif system == 'kuramoto_sivashinsky':
|
elif system == 'kuramoto_sivashinsky':
|
||||||
'''
|
'''
|
||||||
(function) def kuramoto_sivashinsky(
|
(function) def kuramoto_sivashinsky(
|
||||||
@ -181,11 +152,16 @@ def load_data(system='lorenz', init='random', noise=None, intensity=0.1, h=0.01,
|
|||||||
clean_data = datasets.kuramoto_sivashinsky(n_timesteps=n_timesteps, h=h, **kwargs)
|
clean_data = datasets.kuramoto_sivashinsky(n_timesteps=n_timesteps, h=h, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("未知的混沌系统类型。可选项: 'lorenz', 'rossler', 'multiscroll', 'kuramoto_sivashinsky'")
|
raise ValueError("未知的混沌系统类型。可选项: 'lorenz', 'kuramoto_sivashinsky'")
|
||||||
|
|
||||||
|
eps = 1e-8 # 定义一个极小值
|
||||||
|
normalize_01_eps = lambda data: (data - data.min(axis=0)) / (data.max(axis=0) - data.min(axis=0) + eps)
|
||||||
|
|
||||||
if normlization:
|
if normlization:
|
||||||
clean_data = (clean_data - clean_data.mean(axis=0)) / clean_data.std(axis=0) # Z-score 归一化
|
clean_data = normalize_01_eps(clean_data)
|
||||||
clean_data = clean_data[transient:]
|
print('use new 0-1 norm')
|
||||||
|
#clean_data = (clean_data - clean_data.mean(axis=0)) / clean_data.std(axis=0) # Z-score 归一化
|
||||||
|
clean_data = clean_data[transient:,:]
|
||||||
|
|
||||||
# 添加噪声
|
# 添加噪声
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
@ -204,34 +180,360 @@ def load_data(system='lorenz', init='random', noise=None, intensity=0.1, h=0.01,
|
|||||||
|
|
||||||
return clean_data, noisy_data
|
return clean_data, noisy_data
|
||||||
|
|
||||||
|
def get_lorenz(n_timesteps=10000, h=0.01, transient=1000,
|
||||||
|
noise='gaussian', intensity=0.1, x0=[1, 1, 1],
|
||||||
|
rho=28, sigma=10, beta=8/3, normlization=True):
|
||||||
|
"""
|
||||||
|
生成 Lorenz 系统数据.
|
||||||
|
|
||||||
|
参数:
|
||||||
|
n_timesteps: int, 时间步数.
|
||||||
|
h: float, 时间步长.
|
||||||
|
transient: int, 丢弃的时间步数.
|
||||||
|
noise: str, 噪声类型.
|
||||||
|
intensity: float, 噪声强度.
|
||||||
|
x0: list, 初始值.
|
||||||
|
rho, sigma, beta: float, 系统参数.
|
||||||
|
kwargs: 其他参数.
|
||||||
|
返回:
|
||||||
|
(clean_data, noisy_data): 两个 numpy 数组, 分别为干净数据和添加噪声后的数据.
|
||||||
|
"""
|
||||||
|
# 生成 Lorenz 系统数据
|
||||||
|
clean_data, noisy_data = load_data(system='lorenz', noise=noise, intensity=intensity, n_timesteps=n_timesteps, transient=transient, h=h, x0=x0, rho=rho, sigma=sigma, beta=beta, normlization=True)
|
||||||
|
return clean_data, noisy_data
|
||||||
|
|
||||||
|
def get_kuramoto_sivashinsky(n_timesteps=10000, h=0.25, transient=2000,
|
||||||
|
noise='gaussian', intensity=0.1, N=128, M=16,
|
||||||
|
x0=None, normlization=True):
|
||||||
|
"""
|
||||||
|
生成 Kuramoto-Sivashinsky 系统数据.
|
||||||
|
|
||||||
|
参数:
|
||||||
|
n_timesteps: int, 时间步数.
|
||||||
|
h: float, 时间步长.
|
||||||
|
transient: int, 丢弃的时间步数.
|
||||||
|
noise: str, 噪声类型.
|
||||||
|
intensity: float, 噪声强度.
|
||||||
|
N: int, 系统参数.
|
||||||
|
M: float, 系统参数.
|
||||||
|
x0: list | ndarray, 初始值.
|
||||||
|
返回:
|
||||||
|
(clean_data, noisy_data): 两个 numpy 数组, 分别为干净数据和添加噪声后的数据.
|
||||||
|
"""
|
||||||
|
clean_data, noisy_data = load_data(system='kuramoto_sivashinsky', noise=noise, intensity=intensity, n_timesteps=n_timesteps, transient=transient, h=h, N=N, M=M, x0=x0, normlization=True)
|
||||||
|
return clean_data, noisy_data
|
||||||
|
|
||||||
|
def get_wind(noise='gaussian', intensity=0.1, filepath='exp/data/scale_windspeed.txt', normlization=True):
|
||||||
|
'''
|
||||||
|
读取风速数据并添加噪声.
|
||||||
|
|
||||||
|
参数:
|
||||||
|
noise: str, 噪声类型.
|
||||||
|
intensity: float, 噪声强度.
|
||||||
|
filepath: str, 数据文件路径.
|
||||||
|
返回:
|
||||||
|
wind_data: numpy 数组, 添加噪声后的风速数据.
|
||||||
|
'''
|
||||||
|
wind_data = np.loadtxt(filepath).T # shape -> [155, 157819]
|
||||||
|
|
||||||
|
if normlization:
|
||||||
|
wind_data = (wind_data - wind_data.min(axis=0)) / (wind_data.max(axis=0) - wind_data.min(axis=0) + 1e-8)
|
||||||
|
# wind_data = (wind_data - wind_data.mean(axis=0)) / wind_data.std(axis=0) # Z-score 归一化
|
||||||
|
|
||||||
|
if noise is not None:
|
||||||
|
# 添加噪声
|
||||||
|
wind_noisy = add_noise(wind_data, noise_type=noise, intensity=intensity)
|
||||||
|
|
||||||
|
return wind_data, wind_noisy
|
||||||
|
|
||||||
def visualize_data_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix=""):
|
def visualize_data_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix=""):
|
||||||
plt.figure(figsize=(12, 6))
|
"""
|
||||||
for i in range(3):
|
可视化数据比较:噪声数据与干净数据,预测数据与干净数据
|
||||||
plt.subplot(3, 2, 2*i+1)
|
|
||||||
plt.plot(noisy_data[warmup:, i], label='noisy')
|
参数:
|
||||||
plt.plot(clean_data[warmup:, i], label='clean')
|
clean_data: 干净数据
|
||||||
|
noisy_data: 噪声数据
|
||||||
|
prediction: 预测数据
|
||||||
|
warmup: 起始索引
|
||||||
|
title_prefix: 标题前缀
|
||||||
|
"""
|
||||||
|
# 获取数据维度
|
||||||
|
n_dims = 1 if clean_data.ndim == 1 else clean_data.shape[1]
|
||||||
|
|
||||||
|
# 调整图形大小和布局
|
||||||
|
if n_dims <= 3:
|
||||||
|
fig_height = 4 * n_dims
|
||||||
|
fig_width = 12
|
||||||
|
n_cols = 2
|
||||||
|
else:
|
||||||
|
fig_height = 3 * ((n_dims + 1) // 2)
|
||||||
|
fig_width = 14
|
||||||
|
n_cols = 4
|
||||||
|
|
||||||
|
plt.figure(figsize=(fig_width, fig_height))
|
||||||
|
|
||||||
|
for i in range(n_dims):
|
||||||
|
# 获取当前维度的数据
|
||||||
|
if clean_data.ndim == 1:
|
||||||
|
clean_dim = clean_data[warmup:]
|
||||||
|
noisy_dim = noisy_data[warmup:]
|
||||||
|
pred_dim = prediction[warmup:]
|
||||||
|
else:
|
||||||
|
clean_dim = clean_data[warmup:, i]
|
||||||
|
noisy_dim = noisy_data[warmup:, i]
|
||||||
|
pred_dim = prediction[warmup:, i]
|
||||||
|
|
||||||
|
# 噪声数据与干净数据对比
|
||||||
|
plt.subplot(n_dims, n_cols, n_cols*i+1)
|
||||||
|
plt.plot(noisy_dim, label='noisy')
|
||||||
|
plt.plot(clean_dim, label='clean')
|
||||||
plt.title(f"{title_prefix} Dim {i+1} (Noisy vs Clean)")
|
plt.title(f"{title_prefix} Dim {i+1} (Noisy vs Clean)")
|
||||||
plt.legend()
|
if i == 0 or n_dims <= 3:
|
||||||
plt.subplot(3, 2, 2*i+2)
|
plt.legend()
|
||||||
plt.plot(prediction[warmup:, i], label='prediction')
|
|
||||||
plt.plot(clean_data[warmup:, i], label='clean')
|
# 预测数据与干净数据对比
|
||||||
|
plt.subplot(n_dims, n_cols, n_cols*i+2)
|
||||||
|
plt.plot(pred_dim, label='prediction')
|
||||||
|
plt.plot(clean_dim, label='clean')
|
||||||
plt.title(f"{title_prefix} Dim {i+1} (Prediction vs Clean)")
|
plt.title(f"{title_prefix} Dim {i+1} (Prediction vs Clean)")
|
||||||
plt.legend()
|
if i == 0 or n_dims <= 3:
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
def visualize_psd_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix=""):
|
def visualize_psd_diff(clean_data, noisy_data, prediction, warmup=0, title_prefix=""):
|
||||||
plt.figure(figsize=(12, 6))
|
"""
|
||||||
for i in range(3):
|
可视化数据的功率谱密度比较
|
||||||
plt.subplot(3, 1, i+1)
|
|
||||||
plt.psd(clean_data[warmup:, i], NFFT=1024, label='clean')
|
参数:
|
||||||
plt.psd(noisy_data[warmup:, i], NFFT=1024, label='noisy')
|
clean_data: 干净数据
|
||||||
plt.psd(prediction[warmup:, i], NFFT=1024, label='prediction')
|
noisy_data: 噪声数据
|
||||||
|
prediction: 预测数据
|
||||||
|
warmup: 起始索引
|
||||||
|
title_prefix: 标题前缀
|
||||||
|
"""
|
||||||
|
# 获取数据维度
|
||||||
|
n_dims = 1 if clean_data.ndim == 1 else clean_data.shape[1]
|
||||||
|
|
||||||
|
# 调整图形大小
|
||||||
|
fig_height = 4 * min(n_dims, 5) # 限制最大高度
|
||||||
|
fig_width = 12
|
||||||
|
|
||||||
|
plt.figure(figsize=(fig_width, fig_height))
|
||||||
|
|
||||||
|
# 如果维度过多,只显示前5个维度
|
||||||
|
show_dims = min(n_dims, 5)
|
||||||
|
|
||||||
|
for i in range(show_dims):
|
||||||
|
plt.subplot(show_dims, 1, i+1)
|
||||||
|
|
||||||
|
# 获取当前维度的数据
|
||||||
|
if clean_data.ndim == 1:
|
||||||
|
clean_dim = clean_data[warmup:]
|
||||||
|
noisy_dim = noisy_data[warmup:]
|
||||||
|
pred_dim = prediction[warmup:]
|
||||||
|
else:
|
||||||
|
clean_dim = clean_data[warmup:, i]
|
||||||
|
noisy_dim = noisy_data[warmup:, i]
|
||||||
|
pred_dim = prediction[warmup:, i]
|
||||||
|
|
||||||
|
plt.psd(clean_dim, NFFT=1024, label='clean')
|
||||||
|
plt.psd(noisy_dim, NFFT=1024, label='noisy')
|
||||||
|
plt.psd(pred_dim, NFFT=1024, label='prediction')
|
||||||
plt.title(f"{title_prefix} Dim {i+1} PSD Comparison")
|
plt.title(f"{title_prefix} Dim {i+1} PSD Comparison")
|
||||||
plt.legend()
|
plt.legend()
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
|
def visualize_ks_denoising(clean_data, noisy_data, prediction, warmup=0, title_prefix="", time_slice=None):
|
||||||
|
"""
|
||||||
|
可视化Kuramoto-Sivashinsky系统的降噪结果 (修改版)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
clean_data: 干净数据,形状为[时间步数, 空间点数]
|
||||||
|
noisy_data: 噪声数据,形状同上
|
||||||
|
prediction: 预测数据,形状同上
|
||||||
|
warmup: 要从数据开头移除的时间步数
|
||||||
|
title_prefix: 图形总标题的前缀
|
||||||
|
time_slice: 要显示的特定时间索引列表 (最多3个),如果为None则自动选择
|
||||||
|
"""
|
||||||
|
# 1. 数据准备
|
||||||
|
if warmup >= clean_data.shape[0]:
|
||||||
|
print(f"Warning: Warmup period ({warmup}) is longer than or equal to data length ({clean_data.shape[0]}). No data to plot.")
|
||||||
|
return
|
||||||
|
|
||||||
|
clean = clean_data[warmup:]
|
||||||
|
noisy = noisy_data[warmup:]
|
||||||
|
pred = prediction[warmup:]
|
||||||
|
|
||||||
|
if clean.shape[0] == 0:
|
||||||
|
print("Warning: Data is empty after removing warmup period.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 计算误差和MSE
|
||||||
|
mse_noisy = np.mean((clean - noisy) ** 2)
|
||||||
|
mse_pred = np.mean((clean - pred) ** 2)
|
||||||
|
diff_noisy = clean - noisy
|
||||||
|
diff_pred = clean - pred
|
||||||
|
|
||||||
|
# 为差异图确定对称的颜色范围
|
||||||
|
max_abs_diff = max(np.abs(diff_noisy).max(), np.abs(diff_pred).max())
|
||||||
|
# 防止 max_abs_diff 为 0 导致 vmin=vmax
|
||||||
|
if max_abs_diff < 1e-9:
|
||||||
|
max_abs_diff = 1e-9
|
||||||
|
diff_vmin, diff_vmax = -max_abs_diff, max_abs_diff
|
||||||
|
|
||||||
|
# 为数据图确定共享的颜色范围 (可选,但有助于比较)
|
||||||
|
data_vmin = min(clean.min(), noisy.min(), pred.min())
|
||||||
|
data_vmax = max(clean.max(), noisy.max(), pred.max())
|
||||||
|
if data_vmax - data_vmin < 1e-9: # Handle constant data case
|
||||||
|
data_vmax = data_vmin + 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# 2. 创建图形和轴 (4x3 网格)
|
||||||
|
fig, axes = plt.subplots(4, 3, figsize=(15, 17)) # 调整 figsize 以适应 4 行
|
||||||
|
fig.suptitle(f"{title_prefix} Kuramoto-Sivashinsky Denoising Results", fontsize=16)
|
||||||
|
|
||||||
|
# imshow 的通用范围
|
||||||
|
extent = [0, clean.shape[0], 0, clean.shape[1]] # [时间_min, 时间_max, 空间_min, 空间_max]
|
||||||
|
time_label = 'Time Step (after warmup)'
|
||||||
|
space_label = 'Space'
|
||||||
|
amplitude_label = 'Amplitude'
|
||||||
|
difference_label = 'Difference'
|
||||||
|
|
||||||
|
# --- 第 1 行: 干净, 噪声, 干净 - 噪声 ---
|
||||||
|
im1 = axes[0, 0].imshow(clean.T, aspect='auto', cmap='viridis', extent=extent, vmin=data_vmin, vmax=data_vmax)
|
||||||
|
fig.colorbar(im1, ax=axes[0, 0], label=amplitude_label)
|
||||||
|
axes[0, 0].set_title("Clean Data")
|
||||||
|
axes[0, 0].set_ylabel(space_label)
|
||||||
|
|
||||||
|
im2 = axes[0, 1].imshow(noisy.T, aspect='auto', cmap='viridis', extent=extent, vmin=data_vmin, vmax=data_vmax)
|
||||||
|
fig.colorbar(im2, ax=axes[0, 1], label=amplitude_label)
|
||||||
|
axes[0, 1].set_title(f"Noisy Data (MSE: {mse_noisy:.4f})")
|
||||||
|
# axes[0, 1].set_ylabel(space_label) # Y轴标签通常只在最左侧显示
|
||||||
|
|
||||||
|
im3 = axes[0, 2].imshow(diff_noisy.T, aspect='auto', cmap='coolwarm', extent=extent, vmin=diff_vmin, vmax=diff_vmax)
|
||||||
|
fig.colorbar(im3, ax=axes[0, 2], label=difference_label)
|
||||||
|
axes[0, 2].set_title("Clean - Noisy")
|
||||||
|
# axes[0, 2].set_ylabel(space_label)
|
||||||
|
|
||||||
|
# --- 第 2 行: 干净, 预测, 干净 - 预测 ---
|
||||||
|
im4 = axes[1, 0].imshow(clean.T, aspect='auto', cmap='viridis', extent=extent, vmin=data_vmin, vmax=data_vmax)
|
||||||
|
fig.colorbar(im4, ax=axes[1, 0], label=amplitude_label)
|
||||||
|
axes[1, 0].set_title("Clean Data")
|
||||||
|
axes[1, 0].set_ylabel(space_label)
|
||||||
|
axes[1, 0].set_xlabel(time_label) # 在底部imshow行添加时间标签
|
||||||
|
|
||||||
|
im5 = axes[1, 1].imshow(pred.T, aspect='auto', cmap='viridis', extent=extent, vmin=data_vmin, vmax=data_vmax)
|
||||||
|
fig.colorbar(im5, ax=axes[1, 1], label=amplitude_label)
|
||||||
|
axes[1, 1].set_title(f"Prediction (MSE: {mse_pred:.4f})")
|
||||||
|
# axes[1, 1].set_ylabel(space_label)
|
||||||
|
axes[1, 1].set_xlabel(time_label)
|
||||||
|
|
||||||
|
im6 = axes[1, 2].imshow(diff_pred.T, aspect='auto', cmap='coolwarm', extent=extent, vmin=diff_vmin, vmax=diff_vmax)
|
||||||
|
fig.colorbar(im6, ax=axes[1, 2], label=difference_label)
|
||||||
|
axes[1, 2].set_title("Clean - Prediction")
|
||||||
|
# axes[1, 2].set_ylabel(space_label)
|
||||||
|
axes[1, 2].set_xlabel(time_label)
|
||||||
|
|
||||||
|
# --- 第 3 行: 时间切片 ---
|
||||||
|
num_steps = clean.shape[0]
|
||||||
|
if time_slice is None:
|
||||||
|
# 自动选择最多3个均匀分布的时间点
|
||||||
|
if num_steps > 0:
|
||||||
|
# 确保 t_points 是有效的索引,并且尽可能均匀
|
||||||
|
indices = np.linspace(0, num_steps - 1, 5, dtype=int)[1:-1] # 取中间的3个点
|
||||||
|
t_points = sorted(list(set(indices))) # 去重并排序
|
||||||
|
# 如果点数少于3个(因为数据短或linspace结果重复),补充点
|
||||||
|
while len(t_points) < 3 and len(t_points) < num_steps:
|
||||||
|
if 0 not in t_points: t_points.insert(0, 0)
|
||||||
|
elif num_steps - 1 not in t_points: t_points.append(num_steps - 1)
|
||||||
|
else: # 如果头尾都有了,尝试在中间加
|
||||||
|
mid_ish = num_steps // 2
|
||||||
|
if mid_ish not in t_points: t_points.append(mid_ish)
|
||||||
|
else: break # 无法添加更多唯一的点
|
||||||
|
t_points = sorted(list(set(t_points)))[:3] # 保持最多3个
|
||||||
|
else:
|
||||||
|
t_points = [] # 如果没有时间步,则没有切片
|
||||||
|
else:
|
||||||
|
# 过滤用户提供的时间切片以确保有效性,并最多取前3个
|
||||||
|
t_points = [t for t in time_slice if 0 <= t < num_steps]
|
||||||
|
if len(t_points) > 3:
|
||||||
|
print(f"Warning: Provided {len(time_slice)} time slices. Using the first 3 valid ones: {t_points[:3]}")
|
||||||
|
t_points = t_points[:3]
|
||||||
|
elif not t_points and time_slice:
|
||||||
|
print(f"Warning: None of the provided time slices {time_slice} are valid for data length {num_steps}.")
|
||||||
|
|
||||||
|
|
||||||
|
# 绘制时间切片图
|
||||||
|
num_slice_plots = 3
|
||||||
|
for i in range(num_slice_plots):
|
||||||
|
ax_slice = axes[2, i]
|
||||||
|
if i < len(t_points):
|
||||||
|
t = t_points[i]
|
||||||
|
ax_slice.plot(clean[t], label='Clean')
|
||||||
|
ax_slice.plot(noisy[t], label='Noisy', alpha=0.7)
|
||||||
|
ax_slice.plot(pred[t], label='Prediction', linestyle='--')
|
||||||
|
ax_slice.set_title(f"Time slice at t={t}")
|
||||||
|
ax_slice.set_xlabel(space_label)
|
||||||
|
if i == 0: ax_slice.set_ylabel(amplitude_label) # 只在最左侧显示Y轴标签
|
||||||
|
ax_slice.legend()
|
||||||
|
ax_slice.grid(True, linestyle='--', alpha=0.6)
|
||||||
|
else:
|
||||||
|
# 如果没有足够的时间点来绘制,则隐藏多余的子图
|
||||||
|
ax_slice.axis('off')
|
||||||
|
|
||||||
|
|
||||||
|
# --- 第 4 行: 误差分布 & 随时间变化的 MSE ---
|
||||||
|
# 误差直方图
|
||||||
|
bins = 50
|
||||||
|
ax_hist_noisy = axes[3, 0]
|
||||||
|
counts_noisy, _, _ = ax_hist_noisy.hist(diff_noisy.flatten(), bins=bins, alpha=0.7, label='Noisy Error')
|
||||||
|
ax_hist_noisy.set_title('Noise Error Distribution')
|
||||||
|
ax_hist_noisy.set_xlabel('Error (Clean - Noisy)')
|
||||||
|
ax_hist_noisy.set_ylabel('Count')
|
||||||
|
ax_hist_noisy.grid(True, linestyle='--', alpha=0.6)
|
||||||
|
|
||||||
|
ax_hist_pred = axes[3, 1]
|
||||||
|
counts_pred, _, _ = ax_hist_pred.hist(diff_pred.flatten(), bins=bins, alpha=0.7, label='Prediction Error', color='orange')
|
||||||
|
ax_hist_pred.set_title('Prediction Error Distribution')
|
||||||
|
ax_hist_pred.set_xlabel('Error (Clean - Prediction)')
|
||||||
|
# ax_hist_pred.set_ylabel('Count') # Y轴标签通常只在最左侧显示
|
||||||
|
|
||||||
|
# 为直方图设置共享的 y 轴限制
|
||||||
|
max_y = 0
|
||||||
|
if counts_noisy.size > 0: # 检查是否有计数
|
||||||
|
max_y = max(max_y, np.max(counts_noisy))
|
||||||
|
if counts_pred.size > 0: # 检查是否有计数
|
||||||
|
max_y = max(max_y, np.max(counts_pred))
|
||||||
|
|
||||||
|
if max_y > 0: # 仅当绘制了直方图时才设置 ylim
|
||||||
|
common_ylim = (0, max_y * 1.1) # 增加 10% 的空隙
|
||||||
|
ax_hist_noisy.set_ylim(common_ylim)
|
||||||
|
ax_hist_pred.set_ylim(common_ylim)
|
||||||
|
else: # 处理数据为空或恒定的情况
|
||||||
|
ax_hist_noisy.set_ylim(0, 1)
|
||||||
|
ax_hist_pred.set_ylim(0, 1)
|
||||||
|
|
||||||
|
# 随时间变化的 MSE
|
||||||
|
ax_mse = axes[3, 2]
|
||||||
|
time_axis = np.arange(num_steps) # 创建时间轴
|
||||||
|
ax_mse.plot(time_axis, np.mean(diff_noisy**2, axis=1), label='Noisy MSE')
|
||||||
|
ax_mse.plot(time_axis, np.mean(diff_pred**2, axis=1), label='Prediction MSE')
|
||||||
|
ax_mse.set_title('MSE over Time')
|
||||||
|
ax_mse.set_xlabel(time_label)
|
||||||
|
ax_mse.set_ylabel('MSE')
|
||||||
|
ax_mse.set_yscale('log') # MSE 通常最好用对数刻度显示
|
||||||
|
ax_mse.legend()
|
||||||
|
ax_mse.grid(True, linestyle='--', alpha=0.6)
|
||||||
|
|
||||||
|
# --- 最终调整 ---
|
||||||
|
plt.tight_layout(rect=[0, 0.01, 1, 0.96]) # 调整布局以适应 suptitle 和 x 标签
|
||||||
|
# plt.show() # 如果需要立即显示图形,取消注释此行
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 测试数据加载和噪声添加
|
# 测试数据加载和噪声添加
|
||||||
clean_data, noisy_data = load_data(system='lorenz', noise=['gaussian', 'colored'], intensity=0.1, n_timesteps=10000, transient=1000)
|
clean_data, noisy_data = get_kuramoto_sivashinsky()
|
||||||
print("Clean Data Shape:", clean_data.shape)
|
print("Clean Data Shape:", clean_data.shape)
|
||||||
print("Noisy Data Shape:", noisy_data.shape)
|
print("Noisy Data Shape:", noisy_data.shape)
|
||||||
|
visualize_ks_denoising(clean_data, noisy_data, noisy_data)
|
||||||
76
exp/try vis ks.ipynb
Normal file
76
exp/try vis ks.ipynb
Normal file
File diff suppressed because one or more lines are too long
1
server.bat
Normal file
1
server.bat
Normal file
@ -0,0 +1 @@
|
|||||||
|
jupyter-lab --ip="0.0.0.0"
|
||||||
Loading…
x
Reference in New Issue
Block a user