Spaces:
Paused
Paused
File size: 5,614 Bytes
1ac2018 c07c20f 1ac2018 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os.path as osp
import numpy as np
from PIL import Image
import torch
import utils3d
import logging
import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline, TrellisTextTo3DPipeline
from lib.util import generation, partfield
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
def attn_cosine_sim(x, eps=1e-08):
    """Return the pairwise cosine-similarity matrix of a set of feature vectors.

    The leading dimension of ``x`` is dropped with ``x[0]`` (the caller in this
    file passes a tensor with a redundant leading singleton dimension). The
    remaining tensor is treated as ``(..., N, C)``: ``N`` vectors of ``C``
    channels, with any number of leading batch dimensions.

    Args:
        x: Feature tensor with at least three dimensions; only ``x[0]`` is used.
        eps: Lower bound applied to the product of norms so that zero vectors
            yield a similarity of 0 instead of a division by zero.

    Returns:
        Tensor of shape ``(..., N, N)`` where entry ``[..., i, j]`` is the
        cosine similarity between vectors ``i`` and ``j``.
    """
    x = x[0]
    # Per-vector L2 norms, kept as a trailing singleton so the outer product
    # below yields the (N, N) matrix of norm products.
    norms = x.norm(dim=-1, keepdim=True)
    norm_products = torch.clamp(norms @ norms.transpose(-2, -1), min=eps)
    return (x @ x.transpose(-2, -1)) / norm_products
def _part_contrastive_loss(feats, labels):
    """Contrastive loss pulling together features of voxels that share a part label.

    For every voxel, the positives are the *other* voxels carrying the same
    label; the loss is ``-log(sum_pos exp(sim) / sum_others exp(sim))`` using
    the cosine similarities from ``attn_cosine_sim``, averaged over the voxels
    that have at least one positive.

    Args:
        feats: ``(N, C)`` feature tensor (gradients flow into it).
        labels: ``N`` integer part labels, one per row of ``feats``.

    Returns:
        Scalar loss tensor. It is NaN when no voxel has a positive (mean of an
        empty selection), matching the original computation.
    """
    labels = labels.view(-1, 1)
    sim = attn_cosine_sim(feats[None, None, ...])[0]

    same_part = (labels == labels.T).float()
    # Exclude self-pairs from both the positives and the normaliser.
    off_diagonal = torch.ones_like(same_part) - torch.eye(same_part.size(0), device=feats.device)
    positives = same_part * off_diagonal

    exp_sim = torch.exp(sim) * off_diagonal
    numerator = (exp_sim * positives).sum(dim=1)
    denominator = exp_sim.sum(dim=1)

    # Select the rows that have positives BEFORE taking the log. Evaluating
    # log(0) for rows without positives and discarding them afterwards leaves
    # the forward value unchanged but back-propagates 0/0 = NaN into every
    # feature gradient.
    has_positive = positives.sum(dim=1) > 0
    ratio = numerator[has_positive] / (denominator[has_positive] + 1e-8)
    return (-torch.log(ratio)).mean()


def optimize_self_similarity(cfg, app, app_type, output_dir):
    """Run flow sampling with part-aware self-similarity guidance and decode the result.

    Each step performs one ``sample_once`` call of the pipeline's SLAT sampler
    and then (except on the last step) one AdamW update of the sampled
    features against the part-contrastive structure loss. The features with
    the lowest guidance loss are de-normalised and handed to
    ``generation.decode_slat``.

    Args:
        cfg: Configuration object. Fields read here: ``trellis_img_model_name``
            / ``trellis_text_model_name``, ``flow_model_in_channels``,
            ``log_every``, ``sim_guidance.{num_part_clusters, learning_rate,
            steps, loss_weight}`` and ``img_model`` / ``text_model``
            ``.{cfg_strength, cfg_interval, rescale_t}``.
        app: Appearance condition passed to ``get_cond``. For
            ``app_type == 'image'`` this argument is ignored and replaced by
            ``<output_dir>/app_image.png``.
        app_type: ``'image'`` selects the image pipeline; any other value
            selects the text pipeline.
        output_dir: Directory holding ``voxels/struct_voxels.ply`` and
            ``partfield/part_feat_struct_mesh_zup_batch_part_plane.npy``; the
            output paths ``out_sim.glb`` and ``out_gaussian_sim.mp4`` inside it
            are passed to ``generation.decode_slat``.

    Returns:
        None.
    """
    log.info("Starting self-similarity optimization...")

    if app_type == 'image':
        generation_pipeline = TrellisImageTo3DPipeline.from_pretrained(cfg.trellis_img_model_name)
        # Close the file handle deterministically; convert() returns a loaded copy.
        with Image.open(osp.join(output_dir, 'app_image.png')) as app_file:
            app = app_file.convert('RGB')
        app = generation_pipeline.preprocess_image(app)
        model_cfg = cfg.img_model
    else:
        generation_pipeline = TrellisTextTo3DPipeline.from_pretrained(cfg.trellis_text_model_name)
        model_cfg = cfg.text_model
    generation_pipeline.cuda()

    # Load Structure Data
    struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0]
    struct_coords = torch.from_numpy(struct_coords).float().cuda()
    # Presumably maps positions in [-0.5, 0.5) onto a 64^3 index grid -- TODO confirm.
    struct_coords = ((struct_coords + 0.5) * 64).long()
    # Prepend a zero column (batch index) so coords have the 4-column sparse layout.
    zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device)
    struct_coords = torch.cat([zeros, struct_coords], dim=1)

    # Load partfield planes
    path = osp.join(output_dir, "partfield", "part_feat_struct_mesh_zup_batch_part_plane.npy")
    # NOTE(security): allow_pickle=True executes arbitrary pickled code; only
    # safe because this file is expected to be produced by our own pipeline.
    struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()
    struct_labels = partfield.cluster_geoms(struct_coords, struct_part_planes, num_clusters=cfg.sim_guidance.num_part_clusters)

    # Optimization Starts...
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()
    struct_feats_params = torch.nn.Parameter(torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)), requires_grad=True)
    optimizer = torch.optim.AdamW([struct_feats_params], lr=cfg.sim_guidance.learning_rate)
    # Constant multiplier: kept as a hook for a real schedule.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1)

    best_loss = float('inf')
    feats = None

    cond = generation_pipeline.get_cond([app])
    flow_model = generation_pipeline.models['slat_flow_model']
    sampler_params = {
        "cfg_strength": model_cfg.cfg_strength,
        "cfg_interval": model_cfg.cfg_interval,
    }
    rescale_t = model_cfg.rescale_t

    # Time grid from 1 to 0, warped by rescale_t, consumed as (t, t_prev) pairs.
    t_seq = np.linspace(1, 0, cfg.sim_guidance.steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))

    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()

    log.info(f"Beginning self-similarity guidance + flow sampling loop for {len(t_pairs)} steps...")
    last_iteration = len(t_pairs) - 1
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Diffusion
        noise = sp.SparseTensor(
            feats=struct_feats_params.clone().cuda(),
            coords=struct_coords.int(),
        ).cuda()
        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(flow_model, noise, t, t_prev, **cond, **sampler_params)
            sample = out.pred_x_prev
            # Overwrite the parameter storage with the sampler output so the
            # guidance step below refines the freshly sampled features.
            struct_feats_params.data = sample.feats

        # The final step is pure sampling: no loss is computed, so there is
        # nothing to log or to compare against best_loss.
        # NOTE(review): as a consequence the final sample is only decoded via
        # the fallback below; confirm best-loss selection (rather than "always
        # take the last sample") is the intended behaviour.
        if iteration >= last_iteration:
            continue

        # Optimization - Structure Loss
        struct_loss = _part_contrastive_loss(struct_feats_params, struct_labels)
        total_loss = cfg.sim_guidance.loss_weight * struct_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            message = f"Step: {iteration}, Structure Loss: {struct_loss.item():.4f}, Total Loss: {total_loss.item():.4f}"
            log.info(message)

        loss_value = total_loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        # Happens when no guided step ran (steps <= 1) or the loss was never
        # finite; decoding None would fail, so use the final sample instead.
        log.warning("No guided step improved on the initial loss; decoding the final sample instead.")
        feats = struct_feats_params.detach().to(std.device) * std + mean

    # Decode SLAT
    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir, 'out_sim.glb')
    out_gspath = osp.join(output_dir, 'out_gaussian_sim.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)