# Page header captured together with the file (not part of the program):
# Dylan-Kaneshiro's picture / "Update app.py" / commit 3bd038f (verified)
from mabwiser.mab import MAB, LearningPolicy
def bandit_factory(bandit_type, arms):
    """Build a bandit of the requested type and warm-start it.

    Args:
        bandit_type: Name of the model to build: "Epsilon Greedy", "UCB"
            or "Non-Stationary".
        arms: List of arm labels the bandit chooses between.

    Returns:
        The newly built bandit, already updated with one reward of 3 for
        every arm.

    Raises:
        ValueError: If ``bandit_type`` is not one of the supported names.
    """
    if bandit_type == "Non-Stationary":
        model = NSBandit(arms=arms, epsilon=0.3, alpha=0.2)
    else:
        # Both MABWiser-backed models differ only in their learning policy.
        if bandit_type == "Epsilon Greedy":
            policy = LearningPolicy.EpsilonGreedy(epsilon=0.3)
        elif bandit_type == "UCB":
            policy = LearningPolicy.UCB1(alpha=1)
        else:
            raise ValueError("Invalid bandit type")
        model = MAB(arms=arms, learning_policy=policy, seed=1234)

    # Seed every arm with the same reward so no arm starts without an estimate
    # (3 is presumably the midpoint of the 1-5 rating scale -- confirm).
    warm_start_rewards = [3] * len(arms)
    model.partial_fit(decisions=arms, rewards=warm_start_rewards)
    return model
class NSBandit:
    """Epsilon-greedy bandit whose estimates track recent rewards.

    Each arm keeps a running estimate that is moved a fixed fraction
    ``alpha`` towards every new reward, so older rewards count for less and
    less. Arms that have never been rewarded are always tried first.

    Attributes:
        arms: The arm labels passed to the constructor.
        epsilon: Probability of picking a non-best arm in ``predict``.
        alpha: Step size used when updating an arm's estimate.
        means: Mapping of arm -> current estimate (``None`` until the arm
            receives its first reward).
        t: Number of rewards processed so far.
    """

    def __init__(self, arms, epsilon, alpha):
        """Create a bandit over ``arms`` with no reward estimates yet."""
        self.arms = arms
        self.epsilon = epsilon
        self.alpha = alpha
        self.means = {arm: None for arm in arms}
        self.t = 0

    def partial_fit(self, decisions, rewards):
        """Update the estimates with observed ``(decision, reward)`` pairs.

        Args:
            decisions: Arms that were played; each must be one of the arms
                given to the constructor.
            rewards: Reward obtained for each decision, in the same order.

        Raises:
            KeyError: If a decision is not a known arm.
        """
        for arm, reward in zip(decisions, rewards):
            current = self.means[arm]
            if current is None:
                # The first observation becomes the estimate as-is.
                self.means[arm] = reward
            else:
                self.means[arm] = current + self.alpha * (reward - current)
            self.t += 1

    def predict(self):
        """Choose the next arm to play.

        Returns:
            A random never-rewarded arm if any exist. Otherwise, with
            probability ``epsilon`` a random arm other than the current
            best, and the current best arm the rest of the time. With a
            single arm, that arm is always returned.
        """
        untried = [arm for arm, mean in self.means.items() if mean is None]
        if untried:
            return random.choice(untried)

        best = max(self.means, key=self.means.get)
        # Candidates come from the insertion-ordered dict rather than a set,
        # so a seeded ``random`` gives reproducible picks.
        others = [arm for arm in self.means if arm != best]
        # ``others`` is empty when there is only one arm; exploring would
        # otherwise call ``random.choice`` on an empty list and crash.
        if random.random() < self.epsilon and others:
            return random.choice(others)
        return best
import gradio as gr
import random
import pandas as pd
# Load the songs dataset.
# Earlier file locations, kept for reference:
#   "/content/drive/My Drive/MIT/RealTime/songs_single_genre.csv"
#   "/content/songs_single_genre.csv"
file_path = "songs_final.csv"
song_df = pd.read_csv(file_path).rename(columns={"name": "song"})

# Keep only rows that have every field the UI needs and belong to one of the
# supported genres.
_has_all_fields = song_df[['genre', 'song', 'artist', 'img', 'preview']].notna().all(axis=1)
_is_supported_genre = song_df['genre'].isin(['Hip-Hop', 'Pop', 'Reggaeton', 'Country'])
song_df = song_df.loc[_has_all_fields & _is_supported_genre]

# Genres in order of first appearance in the filtered data.
genres = song_df['genre'].unique().tolist()

# Map each genre to its list of (song, artist, img, preview) records.
songs = {}
for genre in genres:
    _genre_rows = song_df.loc[song_df['genre'] == genre, ['song', 'artist', 'img', 'preview']]
    songs[genre] = list(_genre_rows.itertuples(index=False))
# Bandit used as the starting state: epsilon-greedy over every loaded genre.
bandit = bandit_factory("Epsilon Greedy", list(songs))

# First song shown in the UI: a random pick from the starting genre.
initial_genre = "Pop"
_initial_pick = random.choice(songs[initial_genre])
initial_song, initial_artist, initial_img, initial_preview = _initial_pick
initial_song_display = f"{initial_song}\nBy {initial_artist}"
# Rating-submission handler
def submit_rating(cur_genre, rating, MAB):
    """Feed a rating to the bandit and pick the next song.

    Args:
        cur_genre: Genre of the song that was just rated.
        rating: Rating given to that song; used as the bandit reward.
        MAB: Bandit object exposing ``partial_fit`` and ``predict``.

    Returns:
        A tuple ``(genre, song_text, img, preview, MAB)`` for the next song,
        where ``MAB`` is the same bandit object after the update.
    """
    # Reward the genre that was just played, then ask for the next one.
    MAB.partial_fit(decisions=[cur_genre], rewards=[rating])
    next_genre = MAB.predict()

    # Any song of the chosen genre is equally likely.
    pick = random.choice(songs[next_genre])
    title, artist, cover, preview = pick

    return next_genre, f"{title}\nBy {artist}", cover, preview, MAB
# Bandit-model switch handler
def change_bandit(bandit_type):
    """Return a freshly built bandit of ``bandit_type`` over all genres.

    Raises:
        ValueError: If ``bandit_type`` is not supported by ``bandit_factory``.
    """
    genre_arms = list(songs)
    return bandit_factory(bandit_type, genre_arms)
# Gradio UI
with gr.Blocks() as demo:
    # State variables.
    # The bandit state must NOT be called ``MAB``: a ``with`` block does not
    # open a new scope, so that name would overwrite the ``MAB`` class imported
    # from mabwiser at module level, and every later
    # ``bandit_factory("Epsilon Greedy" / "UCB", ...)`` call triggered by the
    # dropdown would fail with "'State' object is not callable".
    bandit_state = gr.State(bandit)
    cur_genre = gr.State(initial_genre)

    # Model selector; starts on the model the initial bandit was built with.
    model = gr.Dropdown(
        ["Epsilon Greedy", "UCB", "Non-Stationary"],
        value="Epsilon Greedy",
        label="Bandit Model",
        interactive=True,
    )

    # Image output component for album cover
    image_output = gr.Image(value=initial_img, label="Album Cover", interactive=False)

    # Text output component for song details
    text_output = gr.Textbox(value=initial_song_display, label="Song", interactive=False)

    # Audio preview component
    audio_output = gr.Audio(value=initial_preview, label="Preview", interactive=False)

    # Slider for rating
    slider = gr.Slider(minimum=1, maximum=5, step=1, label="Rate the song (1-5)")

    # Submit button
    submit_btn = gr.Button("Submit")

    # Changing the model replaces the stored bandit with a freshly built one.
    model.input(
        fn=change_bandit,
        inputs=model,
        outputs=bandit_state,
    )

    # Submitting a rating updates the bandit and shows the next song.
    submit_btn.click(
        fn=submit_rating,
        inputs=[cur_genre, slider, bandit_state],
        outputs=[cur_genre, text_output, image_output, audio_output, bandit_state],
    )

demo.launch()