阅读 9

聚类获取图片的主色

聚类获取图片的主色

问题

我们想知道一张图片的主要颜色是什么?一般可以用RGB颜色来表示。但是这次我们不能用简单的基于像素颜色值统计的方法,因为那样在人类视觉上有一致含义的颜色无法归到一起,比如差异很小的绿色被分开了。

解决方案

这次我们用聚类的方法来做。我们把像素点放到RGB 3维空间中,颜色相近的点会聚成一团,我们只要找到大的聚类,也就知道了这张图片的主要颜色模式。具体的聚类算法我们选用最常见的KMeans

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import KMeans
复制代码
def pixels_cluster_main_color(pixels, k):
    clt = KMeans(n_clusters = k)
    clt.fit(pixels)
    main_colors = clt.cluster_centers_ # 聚类中心即为主要颜色
    labels, counts = np.unique(clt.labels_, return_counts=True) # 统计各聚类的像素数量
    main_rates = counts / len(pixels)
    return main_colors, main_rates

def cluster_main_color(rgb_tensor, k=5):
    assert len(rgb_tensor.shape) == 3 
    assert rgb_tensor.shape[2] == 3
    pixels = rgb_tensor.reshape((-1, 3))
    return pixels_cluster_main_color(pixels, k)

def show_colors(main_colors, main_rates):
    index = 0
    color_tensor = np.zeros((200, 1000, 3), dtype=np.uint8)
    for color, rate in zip(main_colors, main_rates):
        print(f"color rgb = {color}, rate = {rate*100:.2f}%")
        color_width = int(1000 * rate)
        color_tensor[:,index:(index+color_width)] = color
        index += color_width
    plt.axis('off')
    plt.imshow(color_tensor)
    plt.show()
复制代码

我们来测试香蕉和老虎的图片:

def get_rgb_tensor(path):
    img = Image.open(path)
    img = img.convert("RGB") # 通过转换确保图片为RGB格式
    img_tensor = np.asarray(img)
    return img_tensor

def show_img_tensor(img_tensor):
    plt.axis('off')
    plt.imshow(img_tensor)
复制代码
banana_rgb = get_rgb_tensor("assets/banana.jpg")
show_img_tensor(banana_rgb)
复制代码

banana

show_colors(*cluster_main_color(banana_rgb))
复制代码

结果如下,分辨出了几种绿色

color rgb = [168.47056118 168.81861392 22.43038945], rate = 7.53%

color rgb = [ 92.77975804 132.86910682 53.76924582], rate = 56.15%

color rgb = [231.82004601 242.5476643 216.33520056], rate = 5.79%

color rgb = [140.45579287 170.2376754 95.27551844], rate = 18.71%

color rgb = [50.74796788 84.55337381 23.26324552], rate = 11.81%

banana_color

tiger_rgb = get_rgb_tensor("assets/tiger.jpg")
show_img_tensor(tiger_rgb)
复制代码

tiger

show_colors(*cluster_main_color(tiger_rgb))
复制代码

比较好的分辨出了老虎的颜色

color rgb = [193.21527217 151.84653465 110.92006969], rate = 15.35%

color rgb = [20.47013864 16.90836847 2.74709892], rate = 26.79%

color rgb = [137.97838774 103.32118798 64.97793748], rate = 18.76%

color rgb = [239.12869368 207.88085138 175.984823 ], rate = 8.94%

color rgb = [80.86854703 75.73760081 16.58705309], rate = 30.16%

tiger_color

讨论

我们再尝试一张图片,这是一张带有透明通道的png图片,背景所见的白色其实是透明的。

orange_rgb = get_rgb_tensor("assets/orange.png")
show_img_tensor(orange_rgb)
复制代码

orange

show_colors(*cluster_main_color(orange_rgb))
复制代码

白色占据了最大比例

color rgb = [254.98760944 254.96645623 254.86318969], rate = 74.75%

color rgb = [247.01369994 184.5577787 30.02908848], rate = 8.34%

color rgb = [248.67170747 229.11155872 164.6937713 ], rate = 4.31%

color rgb = [243.82883803 135.17609999 5.85291513], rate = 7.86%

color rgb = [245.80695525 202.68227934 87.94639838], rate = 4.74%

orange_color1

该图片的背景是透明的,但是转换为3通道图片后背景变成了白色,对于这种情况,我们可以剔除全透明的像素点再统计。

def get_rgba_tensor(path):
    img = Image.open(path)
    img = img.convert("RGBA") # 通过转换确保图片为RGBA格式
    img_tensor = np.asarray(img)
    return img_tensor

def opaque_main_color(rgba_tensor, k=5):
    assert len(rgba_tensor.shape) == 3 
    assert rgba_tensor.shape[2] == 4
    pixels = rgba_tensor.reshape((-1, 4))
    opaque = (pixels[:, 3]== 255)
    opaque_pixels = pixels[opaque, :3]
    return pixels_cluster_main_color(opaque_pixels, k)
复制代码
orange_rgba = get_rgba_tensor("assets/orange.png")
show_colors(*opaque_main_color(orange_rgba))
复制代码

结果剔除了白色

color rgb = [246.90028579 190.46000861 42.7931723 ], rate = 29.64%

color rgb = [251.26230029 234.63030367 172.81790287], rate = 10.93%

color rgb = [237.11025562 115.31875577 8.2078842 ], rate = 14.93%

color rgb = [248.59684263 159.90122702 6.71614235], rate = 30.29%

color rgb = [247.3824633 207.87406199 98.8504894 ], rate = 14.20%

orange_color2

另外一个问题是,我们很容易发现,这种算法计算速度很慢,因为聚类算法有一个不断迭代的过程。有两个技巧可以加快计算速度,一是把图片尺寸缩小了再计算,二是将KMeans 算法换成更快的 MiniBatchKMeans等算法,这里就不再赘述相关代码了。