1. 基于深度学习的python速通(一)
  2. 基于深度学习的python速通(三)
  3. 基于深度学习的python速通(二)

Matplotlib数据可视化

Matplotlib是Python中最重要的数据可视化库之一,它提供了类似MATLAB的绘图接口,可以创建各种静态、动态和交互式的图表。在深度学习中,Matplotlib常用于可视化训练过程、数据分布、模型结果等。

基础导入与设置

导入Matplotlib时,通常使用pyplot模块,并给其一个简短的别名:

1
2
3
4
5
6
import matplotlib.pyplot as plt
import numpy as np

# 设置中文字体支持(可选)
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号

基本绘图流程

Matplotlib的基本绘图流程包括:

  1. 创建图形和坐标轴
  2. 绘制数据
  3. 添加标签和标题
  4. 显示或保存图形
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import matplotlib.pyplot as plt
import numpy as np

# 创建数据
x = np.linspace(0, 10, 100)
y = np.sin(x)

# 创建图形
plt.figure(figsize=(8, 6))
plt.plot(x, y)
plt.xlabel('X轴')
plt.ylabel('Y轴')
plt.title('正弦函数图')
plt.show()

高级图表类型

热力图

热力图适用于显示数据的密度分布或相关性矩阵。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import matplotlib.pyplot as plt
import numpy as np

# 创建示例数据
data = np.random.rand(10, 12)

plt.figure(figsize=(10, 8))
plt.imshow(data, cmap='viridis', aspect='auto')
plt.colorbar(label='数值')
plt.title('热力图示例')
plt.xlabel('X轴')
plt.ylabel('Y轴')
plt.show()

# 相关性矩阵热力图
import pandas as pd

# 创建示例数据框
df = pd.DataFrame(np.random.randn(100, 5),
columns=['A', 'B', 'C', 'D', 'E'])
correlation_matrix = df.corr()

plt.figure(figsize=(8, 6))
plt.imshow(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.colorbar(label='相关系数')
plt.title('相关性矩阵热力图')
plt.xticks(range(len(correlation_matrix.columns)), correlation_matrix.columns)
plt.yticks(range(len(correlation_matrix.columns)), correlation_matrix.columns)
plt.show()

箱线图

箱线图用于显示数据的分布特征和异常值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import matplotlib.pyplot as plt
import numpy as np

# 生成示例数据
data1 = np.random.normal(100, 15, 200)
data2 = np.random.normal(80, 20, 200)
data3 = np.random.normal(90, 10, 200)
data4 = np.random.normal(70, 25, 200)

plt.figure(figsize=(10, 6))

# 创建箱线图
box_data = [data1, data2, data3, data4]
plt.boxplot(box_data, labels=['组A', '组B', '组C', '组D'])

plt.title('箱线图示例')
plt.ylabel('数值')
plt.grid(True, alpha=0.3)
plt.show()

# 水平箱线图
plt.figure(figsize=(10, 6))
plt.boxplot(box_data, labels=['组A', '组B', '组C', '组D'], vert=False)
plt.title('水平箱线图')
plt.xlabel('数值')
plt.grid(True, alpha=0.3)
plt.show()

小提琴图

小提琴图结合了箱线图和密度图的特点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import numpy as np

# 生成示例数据
data1 = np.random.normal(100, 15, 200)
data2 = np.random.normal(80, 20, 200)
data3 = np.random.normal(90, 10, 200)

plt.figure(figsize=(10, 6))

# 创建小提琴图
violin_data = [data1, data2, data3]
parts = plt.violinplot(violin_data, positions=[1, 2, 3])

# 美化小提琴图
for pc in parts['bodies']:
pc.set_facecolor('lightblue')
pc.set_alpha(0.7)

plt.title('小提琴图示例')
plt.xlabel('组别')
plt.ylabel('数值')
plt.xticks([1, 2, 3], ['组A', '组B', '组C'])
plt.grid(True, alpha=0.3)
plt.show()

极坐标图

极坐标图适用于显示周期性数据或方向性数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import numpy as np

# 创建极坐标数据
theta = np.linspace(0, 2*np.pi, 100)
r1 = 1 + 0.3*np.cos(5*theta)
r2 = 0.8 + 0.2*np.sin(3*theta)

# 创建极坐标子图
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5),
subplot_kw=dict(projection='polar'))

# 第一个极坐标图
ax1.plot(theta, r1, 'b-', linewidth=2)
ax1.set_title('极坐标图1')
ax1.grid(True)

# 第二个极坐标图
ax2.plot(theta, r2, 'r-', linewidth=2)
ax2.fill(theta, r2, alpha=0.3, color='red')
ax2.set_title('极坐标图2')
ax2.grid(True)

plt.tight_layout()
plt.show()

3D图形

Matplotlib也支持3D图形的绘制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# 创建3D数据
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))

# 创建3D图形
fig = plt.figure(figsize=(15, 5))

# 3D表面图
ax1 = fig.add_subplot(131, projection='3d')
ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
ax1.set_title('3D表面图')

# 3D线框图
ax2 = fig.add_subplot(132, projection='3d')
ax2.plot_wireframe(X, Y, Z, color='blue', alpha=0.6)
ax2.set_title('3D线框图')

# 3D散点图
ax3 = fig.add_subplot(133, projection='3d')
x_scatter = np.random.randn(100)
y_scatter = np.random.randn(100)
z_scatter = np.random.randn(100)
ax3.scatter(x_scatter, y_scatter, z_scatter, c=z_scatter, cmap='plasma')
ax3.set_title('3D散点图')

plt.tight_layout()
plt.show()

实用技巧与最佳实践

保存图形

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, linewidth=2)
plt.title('保存图形示例')
plt.xlabel('X轴')
plt.ylabel('Y轴')
plt.grid(True)

# 保存为不同格式
plt.savefig('图形.png', dpi=300, bbox_inches='tight')
plt.savefig('图形.pdf', bbox_inches='tight')
plt.savefig('图形.svg', bbox_inches='tight')

plt.show()

动画制作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# 创建动画数据
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 2*np.pi, 100)
line, = ax.plot(x, np.sin(x))

ax.set_ylim(-2, 2)
ax.set_title('正弦波动画')
ax.grid(True)

def animate(frame):
line.set_ydata(np.sin(x + frame/10))
return line,

# 创建动画
anim = animation.FuncAnimation(fig, animate, frames=200,
interval=50, blit=True)

# 保存动画(需要安装ffmpeg)
# anim.save('animation.gif', writer='pillow')

plt.show()

总结

Matplotlib是Python中最强大的数据可视化库之一,本笔记涵盖了:

  1. 基础知识:导入、基本绘图流程
  2. 基础图表:线图、散点图、柱状图、直方图、饼图
  3. 图形美化:颜色、线型、标记、字体、样式主题
  4. 布局管理:子图、GridSpec、不规则布局
  5. 高级图表:热力图、箱线图、小提琴图、极坐标图、3D图形
  6. 实用技巧:保存图形、动画制作

掌握这些内容后,你就能够创建各种专业的数据可视化图表,为数据分析和机器学习项目提供强有力的支持。

图形美化与样式设置

颜色和线型设置

Matplotlib提供了丰富的颜色和线型选项来美化图表。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

plt.figure(figsize=(12, 8))

# 不同的颜色表示方法
plt.plot(x, np.sin(x), color='red', label='红色')
plt.plot(x, np.sin(x + 0.5), color='#FF5733', label='十六进制颜色')
plt.plot(x, np.sin(x + 1), color=(0.2, 0.8, 0.2), label='RGB元组')
plt.plot(x, np.sin(x + 1.5), color='c', label='单字母缩写')

# 不同的线型
plt.plot(x, np.sin(x + 2), linestyle='-', label='实线')
plt.plot(x, np.sin(x + 2.5), linestyle='--', label='虚线')
plt.plot(x, np.sin(x + 3), linestyle='-.', label='点划线')
plt.plot(x, np.sin(x + 3.5), linestyle=':', label='点线')

plt.xlabel('X轴')
plt.ylabel('Y轴')
plt.title('颜色和线型示例')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

标记样式设置

为数据点添加标记可以使图表更加清晰。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 1)
y1 = x ** 2
y2 = x ** 1.5

plt.figure(figsize=(10, 6))

# 不同的标记样式
plt.plot(x, y1, marker='o', markersize=8, label='圆形标记')
plt.plot(x, y2, marker='s', markersize=8, label='方形标记')
plt.plot(x, y1 - 10, marker='^', markersize=8, label='三角形标记')
plt.plot(x, y2 - 10, marker='D', markersize=8, label='菱形标记')

plt.xlabel('X轴')
plt.ylabel('Y轴')
plt.title('标记样式示例')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

字体和文本设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import matplotlib.pyplot as plt
import numpy as np

# 设置全局字体
plt.rcParams['font.size'] = 12
plt.rcParams['font.family'] = 'serif'

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, linewidth=2, color='blue')

# 设置标题和标签的字体
plt.xlabel('X轴标签', fontsize=14, fontweight='bold')
plt.ylabel('Y轴标签', fontsize=14, fontweight='bold')
plt.title('字体设置示例', fontsize=16, fontweight='bold', color='darkred')

# 添加文本注释
plt.text(5, 0.5, '这是一个注释', fontsize=12,
bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))

# 添加箭头注释
plt.annotate('最大值', xy=(np.pi/2, 1), xytext=(3, 0.8),
arrowprops=dict(arrowstyle='->', color='red', lw=2),
fontsize=12, color='red')

plt.grid(True, alpha=0.3)
plt.show()

图表样式主题

Matplotlib提供了多种预设的样式主题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import matplotlib.pyplot as plt
import numpy as np

# 查看可用样式
print("可用样式:", plt.style.available)

x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# 使用不同样式
styles = ['default', 'seaborn', 'ggplot', 'dark_background']

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.ravel()

for i, style in enumerate(styles):
with plt.style.context(style):
axes[i].plot(x, y1, label='sin(x)')
axes[i].plot(x, y2, label='cos(x)')
axes[i].set_title(f'样式: {style}')
axes[i].legend()
axes[i].grid(True)

plt.tight_layout()
plt.show()

子图与布局管理

基础子图创建

使用subplot()函数可以在一个图形窗口中创建多个子图。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# 创建2x2的子图布局
plt.figure(figsize=(12, 10))

# 第一个子图
plt.subplot(2, 2, 1)
plt.plot(x, np.sin(x), 'b-')
plt.title('sin(x)')
plt.grid(True)

# 第二个子图
plt.subplot(2, 2, 2)
plt.plot(x, np.cos(x), 'r-')
plt.title('cos(x)')
plt.grid(True)

# 第三个子图
plt.subplot(2, 2, 3)
plt.plot(x, np.tan(x), 'g-')
plt.title('tan(x)')
plt.ylim(-5, 5)
plt.grid(True)

# 第四个子图
plt.subplot(2, 2, 4)
plt.plot(x, np.exp(-x/5), 'm-')
plt.title('exp(-x/5)')
plt.grid(True)

plt.tight_layout()
plt.show()

使用subplots()函数

subplots()函数提供了更灵活的子图创建方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# 创建子图
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# 绘制不同的图形
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('sin(x)')

axes[0, 1].plot(x, np.cos(x), 'r')
axes[0, 1].set_title('cos(x)')

axes[0, 2].scatter(x[::10], np.sin(x[::10]), c='green')
axes[0, 2].set_title('散点图')

axes[1, 0].bar(range(10), np.random.rand(10))
axes[1, 0].set_title('柱状图')

axes[1, 1].hist(np.random.randn(1000), bins=30)
axes[1, 1].set_title('直方图')

axes[1, 2].pie([1, 2, 3, 4], labels=['A', 'B', 'C', 'D'])
axes[1, 2].set_title('饼图')

# 调整布局
plt.tight_layout()
plt.show()

不规则子图布局

使用subplot2grid()可以创建不规则的子图布局。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

plt.figure(figsize=(12, 8))

# 创建不规则布局
ax1 = plt.subplot2grid((3, 3), (0, 0), colspan=2)
ax1.plot(x, np.sin(x))
ax1.set_title('大图 - sin(x)')

ax2 = plt.subplot2grid((3, 3), (0, 2))
ax2.plot(x, np.cos(x), 'r')
ax2.set_title('cos(x)')

ax3 = plt.subplot2grid((3, 3), (1, 0))
ax3.scatter(x[::10], np.random.rand(len(x[::10])))
ax3.set_title('散点图')

ax4 = plt.subplot2grid((3, 3), (1, 1), colspan=2, rowspan=2)
ax4.plot(x, np.exp(-x/5))
ax4.set_title('大图 - exp(-x/5)')

ax5 = plt.subplot2grid((3, 3), (2, 0))
ax5.bar(range(5), np.random.rand(5))
ax5.set_title('柱状图')

plt.tight_layout()
plt.show()

GridSpec高级布局

GridSpec提供了最灵活的子图布局控制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

x = np.linspace(0, 10, 100)

# 创建GridSpec
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(3, 4, figure=fig)

# 创建不同大小的子图
ax1 = fig.add_subplot(gs[0, :2])
ax1.plot(x, np.sin(x))
ax1.set_title('横跨两列')

ax2 = fig.add_subplot(gs[0, 2:])
ax2.plot(x, np.cos(x), 'r')
ax2.set_title('横跨两列')

ax3 = fig.add_subplot(gs[1:, 0])
ax3.plot(np.sin(x), x)
ax3.set_title('纵跨两行')

ax4 = fig.add_subplot(gs[1, 1:])
ax4.scatter(x[::5], np.random.rand(len(x[::5])))
ax4.set_title('横跨三列')

ax5 = fig.add_subplot(gs[2, 1])
ax5.bar(range(3), [1, 2, 3])
ax5.set_title('小图1')

ax6 = fig.add_subplot(gs[2, 2])
ax6.pie([1, 2, 3], labels=['A', 'B', 'C'])
ax6.set_title('小图2')

ax7 = fig.add_subplot(gs[2, 3])
ax7.hist(np.random.randn(100), bins=10)
ax7.set_title('小图3')

plt.tight_layout()
plt.show()

基础绘图功能

线图(Line Plot)

线图是最基本的图表类型,用于显示数据随时间或其他连续变量的变化趋势。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import matplotlib.pyplot as plt
import numpy as np

# 创建数据
x = np.linspace(0, 2*np.pi, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# 绘制多条线
plt.figure(figsize=(10, 6))
plt.plot(x, y1, label='sin(x)', color='blue', linewidth=2)
plt.plot(x, y2, label='cos(x)', color='red', linewidth=2, linestyle='--')

plt.xlabel('X')
plt.ylabel('Y')
plt.title('三角函数图')
plt.legend() # 显示图例
plt.grid(True) # 显示网格
plt.show()

散点图(Scatter Plot)

散点图用于显示两个变量之间的关系,每个点代表一个数据样本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import matplotlib.pyplot as plt
import numpy as np

# 创建随机数据
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100)

# 绘制散点图
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.6, c='blue', s=50)
plt.xlabel('X值')
plt.ylabel('Y值')
plt.title('散点图示例')
plt.grid(True, alpha=0.3)
plt.show()

柱状图(Bar Plot)

柱状图用于比较不同类别的数值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt
import numpy as np

# 创建数据
categories = ['A', 'B', 'C', 'D', 'E']
values = [23, 45, 56, 78, 32]

# 绘制柱状图
plt.figure(figsize=(8, 6))
bars = plt.bar(categories, values, color=['red', 'green', 'blue', 'orange', 'purple'])
plt.xlabel('类别')
plt.ylabel('数值')
plt.title('柱状图示例')

# 在柱子上显示数值
for bar, value in zip(bars, values):
plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
str(value), ha='center', va='bottom')

plt.show()

直方图(Histogram)

直方图用于显示数据的分布情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt
import numpy as np

# 创建正态分布数据
np.random.seed(42)
data = np.random.normal(100, 15, 1000)

# 绘制直方图
plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
plt.xlabel('数值')
plt.ylabel('频率')
plt.title('正态分布直方图')
plt.grid(True, alpha=0.3)
plt.show()

饼图(Pie Chart)

饼图用于显示各部分占整体的比例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt

# 创建数据
labels = ['Python', 'Java', 'JavaScript', 'C++', '其他']
sizes = [35, 25, 20, 15, 5]
colors = ['gold', 'lightcoral', 'lightskyblue', 'lightgreen', 'pink']
explode = (0.1, 0, 0, 0, 0) # 突出显示第一个扇形

# 绘制饼图
plt.figure(figsize=(8, 8))
plt.pie(sizes, explode=explode, labels=labels, colors=colors,
autopct='%1.1f%%', shadow=True, startangle=90)
plt.title('编程语言使用比例')
plt.axis('equal') # 确保饼图是圆形的
plt.show()