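# Convert locally saved Diffuser RL checkpoints (hopper-medium-v2) to the diffusers
# UNet1DModel format: rebuild each architecture from an explicit config, remap the
# original state dict onto the new parameter names, and write the result as
# diffusion_pytorch_model.bin + config.json under hub/hopper-medium-v2/.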
import json
import os

import torch

from diffusers import UNet1DModel

os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)

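
# Rebuild the Diffuser temporal U-Net (horizon 32 or 128) as a diffusers UNet1DModel
# and copy the original weights over by zipping the two state dicts key-by-key
# (the parameter orderings are assumed to match).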
def unet(hor):
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")

    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")

    # Original Diffuser checkpoint, saved locally as a full pickled model.
    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()

    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,  # 11 observation dims + 3 action dims for hopper
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )

    hf_value_function = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    # Map old parameter names to new ones by position, then rename in place.
    mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)

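
# The value function reuses the same 1D U-Net backbone, but with no up blocks
# (up_block_types=()), downsampling in every down block, and a
# ValueFunctionMidBlock1D / "ValueFunction" out block in place of the
# convolutional output head.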
def value_function():
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    # Unlike the U-Net checkpoint, this file is saved as a bare state dict.
    model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    state_dict = model

    hf_value_function = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    # Same positional key remapping as in unet().
    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)

    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)

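
# A minimal load-back sketch (an addition, not part of the original conversion
# script): the directories written above follow diffusers' local checkpoint layout
# (config.json + diffusion_pytorch_model.bin), so they should reload with
# from_pretrained. Call it manually after running the conversion, e.g. _sanity_check(32).
def _sanity_check(hor=32):
    unet_model = UNet1DModel.from_pretrained(f"hub/hopper-medium-v2/unet/hor{hor}")
    sample = torch.randn(1, 14, hor)  # (batch, transition channels, horizon)
    timestep = torch.tensor([0])
    out = unet_model(sample, timestep).sample
    print(f"round-trip forward pass OK, output shape: {tuple(out.shape)}")
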
if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()