fishie-lee's picture
minor improvement
c7f6f2c
import gradio as gr
from PIL import Image
import numpy as np
import requests
import os
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hub checkpoint shared by the image processor and the classifier below.
_MODEL_ID = "wesleyacheng/dog-breeds-multiclass-image-classification-with-vit"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the preprocessing pipeline and the classifier from the same checkpoint,
# then move the classifier's weights onto the selected device.
image_processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
model = model.to(device)

# Class-index -> label-name lookup, with the indices coerced to int.
labels = {}
for class_id, class_name in model.config.id2label.items():
    labels[int(class_id)] = class_name
print(labels)
def _format_label(label):
    """Turn a raw label such as ``golden_retriever`` into ``Golden Retriever``."""
    return ' '.join(word.capitalize() for word in label.split('_'))


def predict(image: Image.Image, k=10):
    """Classify an image and return its top-``k`` breeds with probabilities.

    Args:
        image: PIL image to classify.
        k: Maximum number of breeds to return. It is clamped to the number
            of classes the model outputs, so an oversized ``k`` returns every
            class instead of making ``torch.topk`` raise.

    Returns:
        A dict mapping a display-formatted breed name (underscores replaced
        by spaces, each word capitalized) to its softmax probability as a
        float. Entries are inserted in the order ``torch.topk`` returns them
        (highest probability first by default).
    """
    # Normalize grayscale / RGBA / palette images to RGB so the processor
    # always sees the same channel count.
    # NOTE(review): assumes the checkpoint expects 3-channel input - confirm.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
        # Index 0 selects the single image in the batch.
        probs = F.softmax(outputs.logits, dim=-1)[0]

    # torch.topk raises if k is larger than the dimension it selects from.
    k = min(k, probs.shape[-1])
    top_probs, top_idxs = torch.topk(probs, k=k)

    return {
        _format_label(labels[idx]): float(prob)
        for prob, idx in zip(top_probs.cpu().tolist(), top_idxs.cpu().tolist())
    }
def classify(image: Image.Image) -> dict:
    """Gradio callback: classify the uploaded image with ``predict``.

    Args:
        image: The PIL image supplied by the Gradio image input, or ``None``
            when the user submitted without providing one.

    Returns:
        The label -> probability dict produced by ``predict``.

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    return predict(image)
title = "Dog Breed Classifier 🐶"
description = "A ViT model fine-tuned to classify images of dog breeds."

# Example images are optional: a missing "examples" directory must not crash
# the app at import time, so fall back to an empty list in that case.
_EXAMPLES_DIR = "examples"
# Real file extensions (with the dot), compared case-insensitively so that
# "photo.JPG" is accepted and a name like "notajpg" is not.
_EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png")
example_list = (
    [
        os.path.join(_EXAMPLES_DIR, f)
        # os.listdir order is arbitrary; sort for a stable gallery order.
        for f in sorted(os.listdir(_EXAMPLES_DIR))
        if f.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)

# Single-image-in, label-out interface wired to the classify callback.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title=title,
    description=description,
    # Pass None rather than an empty list when there are no example images.
    examples=example_list or None,
)

if __name__ == "__main__":
    demo.launch()