# Bioinspired diffusion LLMs

Collection of diffusion LLMs adapted for materials science and biology.
This model is a fine-tuned version of GSAI-ML/LLaDA-8B-Instruct, trained on a dataset of bio-inspired materials.
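If you only want to load the weights without dLLM, the base LLaDA stack loads through plain `transformers`. This is a minimal sketch, assuming this fine-tune ships the same custom modeling code as the base model; generation still needs a diffusion sampling loop, so the dLLM examples below are the practical route:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "lamm-mit/LLaDA-8B-Bioinspired-dLLM-Instruct-11-21-2025"

# LLaDA uses custom modeling code, so trust_remote_code is required
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).eval()
```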
Make sure to install dLLM:
```bash
git clone https://github.com/ZHZisZZ/dllm.git
cd dllm
pip install -e .
```
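The scripts below resolve checkpoint paths with `dllm.utils.resolve_with_base_env`. Based on that name, the assumption here is that pointing a `BASE_MODELS_DIR` environment variable at a local model directory makes the loader prefer local copies over Hub downloads; it is optional either way:

```bash
# Optional: resolve model names against a local checkpoint directory
# (assumed behavior of resolve_with_base_env; skip this to pull from the Hub)
export BASE_MODELS_DIR=/path/to/local/models
```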
Example inference:
```python
from dataclasses import dataclass

import transformers

import dllm
from dllm.pipelines import llada
from dllm.tools.chat import decode_trim

# Log in if the checkpoint requires authentication
# (or run `huggingface-cli login` in your shell):
# from huggingface_hub import login
# login(token="hf_...")
# ---------------------------------------------------------
# Load model + tokenizer
# ---------------------------------------------------------
@dataclass
class ScriptArguments:
    model_name_or_path: str = "lamm-mit/LLaDA-8B-Bioinspired-dLLM-Instruct-11-21-2025"

    def __post_init__(self):
        # Optionally resolve the model name against $BASE_MODELS_DIR (local checkpoints)
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


script_args = ScriptArguments()
transformers.set_seed(42)

model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)

generator = llada.LLaDAGenerator(
    model=model,
    tokenizer=tokenizer,
)

gen_config = llada.LLaDAGeneratorConfig(
    steps=256,                   # diffusion denoising steps
    max_new_tokens=256,
    block_length=32,             # semi-autoregressive block size
    temperature=0.0,             # 0.0 = deterministic decoding
    remasking="low_confidence",  # remask lowest-confidence tokens each step
)
# ---------------------------------------------------------
# Batched inference step
# ---------------------------------------------------------
messages_batch = [
    [{"role": "user", "content": "Explain materiomics briefly."}],
    [{"role": "user", "content": "Define mechanobiology in one paragraph."}],
    [{"role": "user", "content": "Why is silk stronger than elastin?"}],
]

inputs = tokenizer.apply_chat_template(
    messages_batch,
    add_generation_prompt=True,
    tokenize=True,
)

outputs = generator.generate(
    inputs,
    gen_config,
    return_dict_in_generate=True,
)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
# ---------------------------------------------------------
# Results
# ---------------------------------------------------------
for i, s in enumerate(sequences):
    print("\n" + "-" * 70)
    print(f"[Sample {i}]")
    print("-" * 70)
    print(s.strip())
```
Visualize the step-by-step denoising history in the terminal:
```python
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
    tokenizer=tokenizer
)
terminal_visualizer.visualize(outputs.histories, rich=True)
```
Because LLaDA generates by iteratively demasking tokens, it can also infill: mask tokens placed anywhere in the conversation, including inside an assistant turn, are filled in during denoising. Example extracting reasoning and design principles via infilling:
```python
gen_config = llada.LLaDAGeneratorConfig(
    steps=512,
    max_new_tokens=512,
    block_length=32,
    temperature=0.2,
    remasking="low_confidence",
)
masked_messages = [
    [
        {
            "role": "user",
            "content": (
                "In spider-silk materiomics, we often optimize hierarchical structure "
                "from amino-acid sequence to β-sheet nanocrystal arrangement. "
                "Complete the missing reasoning steps for the following design question:\n\n"
                "**Design Problem:** How could one tune the fraction of β-sheet "
                "nanocrystals to increase toughness without compromising elasticity?\n\n"
                f"Missing reasoning: {tokenizer.mask_token * 128}"
            ),
        },
        {
            "role": "assistant",
            "content": f"The summary is: {tokenizer.mask_token * 20}",
        },
    ],
    [
        {
            "role": "user",
            "content": (
                "In nacre-inspired composite design, we often tune the architecture of "
                "brick-and-mortar layers to balance stiffness, strength, and toughness. "
                "Complete the missing reasoning steps for the following design question:\n\n"
                "**Design Problem:** How could one introduce controlled mineral platelet "
                "misalignment to enhance toughness while preserving high stiffness?\n\n"
                f"Missing reasoning: {tokenizer.mask_token * 128}"
            ),
        },
        {
            "role": "assistant",
            "content": f"The design principle is: {tokenizer.mask_token * 20}",
        },
    ],
]
# Tokenize WITHOUT a generation prompt: the masked assistant turns are already in place
inputs = tokenizer.apply_chat_template(
    masked_messages,
    add_generation_prompt=False,
    tokenize=True,
)
# Infilling
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
# Print results
for idx, (inp, filled) in enumerate(zip(inputs, sequences)):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
    print("-" * 80)
    print("[Masked]:\n" + tokenizer.decode(inp))
    print("\n[Filled]:\n" + (filled.strip() if filled.strip() else "<empty>"))
    print("\n" + "=" * 80 + "\n")

terminal_visualizer.visualize(outputs.histories, rich=True)
```
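Since the masked assistant turns anchor each answer behind a fixed prefix ("The summary is:", "The design principle is:"), a small post-processing helper can isolate just the infilled spans. This is a hypothetical helper, not part of dLLM:

```python
def extract_after(text: str, prefix: str) -> str:
    """Return the text the model filled in after `prefix` (hypothetical helper)."""
    idx = text.find(prefix)
    return text[idx + len(prefix):].strip() if idx != -1 else text.strip()

print(extract_after(sequences[0], "The summary is:"))
print(extract_after(sequences[1], "The design principle is:"))
```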
**Base model:** [GSAI-ML/LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)