Gradio实现算法前端可视化

1,254 阅读3分钟

Gradio

最近看到一个新的包,感觉挺好玩的,这里记录一下。

与他人共享机器学习模型,API或数据科学工作流程的最佳方法之一是创建一个交互式应用程序,使用户或同事可以在浏览器中尝试演示。

Gradio允许在Python中构建演示并分享它们,通常只需要几行代码即可完成之前需要写flask后端服务,前端开发用户界面等复杂工作。

安装

pip install gradio

Hello World

开始学起都是从输出"hello world"开始,这里也不例外。

import gradio as gr

def greet(name):
    return "Hello " + name + "!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")

demo.launch()

上面的代码运行后,在http// localhost:7860上弹出浏览器:

在这里插入图片描述

左边输入对应的name,右边有基于程序的输出:

在这里插入图片描述

接口类

为了制作演示,上面创建了gradio.interface此接口类可以让用户界面包装任何Python函数。在上面的示例中,演示了一个简单的基于文本的功能,接收文本输入,并添加hello+该文本输出。此功能不仅仅局限于该功能,可以是任何函数,比如音乐生成器,税收计算器再到机器学习模型的推理功能的任何内容。

核心接口类Interface以三个必需的参数初始化:

  • fn:包裹UI的功能函数

  • input:用于输入的哪个组件(例如 "text", "image" 或者"audio"

  • output:用于输出的哪个组件(例如 "text", "image" 或者"audio"

组件属性:

  • textbox: 比text有更大的空间,且可以添加提示字符串 ,用法: inputs=gr.Textbox(lines=2, placeholder="Name Here...") 在这里插入图片描述

多输入和输出组件

假设输入和输出比较复杂,有多个的情况下,可以通过列表的方式进行传递参数。下面是一个计算器,实现加减乘除的功能,输入是两个数字,中间可以加四种运算符中的其中一种,在输入中只需要将其按照对应的格式排列好即可。[[5, "add", 3],] 表示5+3。可以用example表示默认示例。

import gradio as gr

def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        if num2 == 0:
            raise gr.Error("Cannot divide by zero!")
        return num1 / num2

demo = gr.Interface(
    calculator,
    [
        "number", 
        gr.Radio(["add", "subtract", "multiply", "divide"]),
        "number"
    ],
    "number",
    examples=[
        [5, "add", 3],
        [4, "divide", 2],
        [-4, "multiply", 2.5],
        [0, "subtract", 1.2],
    ],
    title="Toy Calculator",
    description="Here's a sample toy calculator. Enjoy!",
)
if __name__ == "__main__":
    demo.launch()

在这里插入图片描述

更多接口属性可以参考官方文档。

模型推理

模型推理可能是更加适合的场景。如果你是一名AI工程师或数据科学家,可能熟悉以下情况:刚刚花了数周的时间开发新的机器学习模型,终于对其性能感到满意,并且想向其它人展示它。这个时候如果只是单独的跑一下程序得到一个数字或者字符串可能不是那么直观,Gradio就能帮助我们省略掉写接口服务及前端演示界面的开发等环节。

下面介绍如何将分类模型PyTorch+ResNet152, 这块需要自己装一下PyTorch环境, CPU/GPU都可以

pip install torch==1.10.0 torchvision==0.11.0
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
import requests

# load model 机器要能联网,需要下载训练好的公开模型
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet152', pretrained=True).eval()


# load 1000类标签
labels = []
# labels.txt 可以通过`https://git.io/JJkYN`下载
with open('labels.txt') as f:
    for ln in f:
        label = ln.rstrip('\n')
        labels.append(label)


def predict(inp):
  inp = Image.fromarray(inp.astype('uint8'), 'RGB')
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
  return {labels[i]: float(prediction[i]) for i in range(1000)}


inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=3)

# share = True 表示可以生成一个url链接,公众通过该url就能体验该功能,有效期72h
gr.Interface(fn=predict, inputs=inputs, outputs=outputs).launch(share=True)  

在这里插入图片描述

在这里插入图片描述

注意

如果想指定端口号且不想用127.0.0.1这个地址而是用原始的服务器ip地址,可以在demo.launch()中加一些指定的参数:

demo.launch(server_name='xx.xx.xx.xx', server_port=8080)

参考