Update neuroprompt_deep.py
neuroprompt_deep.py CHANGED (+9 -9)
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class NeuroPromptDeep:
-    def __init__(self, model_name: str = "gpt2"):
+    def __init__(self, model_name: str = "gpt2"):
         """
         Initialize the generative AI engine with a local language model.
         """
@@ -74,20 +74,20 @@ class NeuroPromptDeep:
             cache_dir=self.cache_dir
         )
 
-        # CPU-
+        # SIMPLIFIED CPU-ONLY CONFIGURATION
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            cache_dir=self.cache_dir,
-            device_map="auto",  # Will automatically use CPU
-            low_cpu_mem_usage=True  # Reduces memory footprint
+            cache_dir=self.cache_dir
         )
 
-        # CPU
+        # Explicitly move model to CPU
+        self.model = self.model.to('cpu')
+
+        # CPU-based pipeline without device conflict
         self.generator = pipeline(
             "text-generation",
             model=self.model,
-            tokenizer=self.tokenizer,
-            device=-1  # Force CPU usage
+            tokenizer=self.tokenizer
        )
 
         logger.info("Model loaded successfully for CPU!")
@@ -117,7 +117,7 @@
         # Format the full prompt
         full_prompt = f"{system_prompt}\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
 
-        # Generate response
+        # Generate response with CPU-friendly settings
         outputs = self.generator(
             full_prompt,
             max_new_tokens=settings["max_length"],