vedaco commited on
Commit
47c9da0
·
verified ·
1 Parent(s): 304f12d

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +43 -933
train.py CHANGED
@@ -1,955 +1,65 @@
1
- """Training for Veda Programming Assistant with Distillation Support"""
2
-
3
  import tensorflow as tf
4
- from tensorflow import keras
5
- import numpy as np
6
- import os
7
- import json
8
-
9
  from model import VedaProgrammingLLM
10
  from tokenizer import VedaTokenizer
11
- from database import db
12
- from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
13
-
14
 
15
- # Training data with proper format (no markdown code blocks - using <CODE> tags)
16
  TRAINING_DATA = """
17
- <USER> Hello!
18
- <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I can help you with coding, explain concepts, or answer programming questions. What would you like to learn today?
19
-
20
- <USER> What is Python?
21
- <ASSISTANT> Python is a high-level, interpreted programming language created by Guido van Rossum in 1991. It is known for its simple, readable syntax.
22
-
23
- Key features:
24
- - Easy to learn and read
25
- - Dynamic typing
26
- - Large standard library
27
- - Great for beginners and experts
28
-
29
- Python is used for web development, data science, AI, automation, and more!
30
-
31
- <USER> How do I print hello world?
32
- <ASSISTANT> To print Hello World in Python, use the print function:
33
-
34
- <CODE>
35
- print("Hello, World!")
36
- <ENDCODE>
37
-
38
- When you run this, it displays: Hello, World!
39
-
40
- <USER> Write a function to add two numbers
41
- <ASSISTANT> Here is a simple function to add two numbers:
42
-
43
- <CODE>
44
- def add(a, b):
45
- return a + b
46
-
47
- result = add(5, 3)
48
- print(result)
49
- <ENDCODE>
50
-
51
- Output: 8
52
-
53
- This function takes two parameters and returns their sum.
54
-
55
- <USER> Explain what a function is
56
- <ASSISTANT> A function is a reusable block of code that performs a specific task. Functions help organize code and avoid repetition.
57
-
58
- Here is the basic structure:
59
-
60
- <CODE>
61
- def function_name(parameters):
62
- # Code block
63
- return result
64
- <ENDCODE>
65
-
66
- Parts of a function:
67
- - def: keyword to define a function
68
- - function_name: what you call your function
69
- - parameters: input values
70
- - return: sends back a result
71
-
72
- Example:
73
-
74
- <CODE>
75
- def greet(name):
76
- message = f"Hello, {name}!"
77
- return message
78
-
79
- print(greet("Alice"))
80
- <ENDCODE>
81
-
82
- Output: Hello, Alice!
83
-
84
- <USER> What is a list?
85
- <ASSISTANT> A list is an ordered, mutable collection in Python that can hold items of any type.
86
-
87
- Creating lists:
88
-
89
- <CODE>
90
- my_list = []
91
- numbers = [1, 2, 3, 4, 5]
92
- fruits = ["apple", "banana", "cherry"]
93
- mixed = [1, "hello", 3.14, True]
94
- <ENDCODE>
95
-
96
- Common operations:
97
-
98
- <CODE>
99
- fruits = ["apple", "banana"]
100
-
101
- fruits.append("cherry")
102
- fruits.insert(0, "apricot")
103
-
104
- print(fruits[0])
105
- print(fruits[-1])
106
-
107
- fruits.remove("banana")
108
- fruits.pop()
109
-
110
- print(len(fruits))
111
- <ENDCODE>
112
-
113
- <USER> Write a fibonacci function
114
- <ASSISTANT> Here is a Fibonacci function:
115
-
116
- Recursive approach:
117
-
118
- <CODE>
119
- def fibonacci(n):
120
- if n <= 0:
121
- return 0
122
- if n == 1:
123
- return 1
124
- return fibonacci(n - 1) + fibonacci(n - 2)
125
-
126
- for i in range(10):
127
- print(fibonacci(i), end=" ")
128
- <ENDCODE>
129
-
130
- Output: 0 1 1 2 3 5 8 13 21 34
131
-
132
- Iterative approach (more efficient):
133
-
134
- <CODE>
135
- def fibonacci_iterative(n):
136
- if n <= 0:
137
- return 0
138
- if n == 1:
139
- return 1
140
- a, b = 0, 1
141
- for _ in range(2, n + 1):
142
- a, b = b, a + b
143
- return b
144
- <ENDCODE>
145
-
146
- <USER> What is recursion?
147
- <ASSISTANT> Recursion is when a function calls itself to solve a problem.
148
-
149
- Key components:
150
- 1. Base case - stops the recursion
151
- 2. Recursive case - function calls itself
152
-
153
- Example - countdown:
154
-
155
- <CODE>
156
- def countdown(n):
157
- if n <= 0:
158
- print("Done!")
159
- return
160
- print(n)
161
- countdown(n - 1)
162
-
163
- countdown(5)
164
- <ENDCODE>
165
-
166
- Output: 5, 4, 3, 2, 1, Done!
167
-
168
- Factorial example:
169
-
170
- <CODE>
171
- def factorial(n):
172
- if n <= 1:
173
- return 1
174
- return n * factorial(n - 1)
175
-
176
- print(factorial(5))
177
- <ENDCODE>
178
-
179
- Output: 120
180
-
181
- <USER> Write bubble sort
182
- <ASSISTANT> Here is the Bubble Sort algorithm:
183
-
184
- <CODE>
185
- def bubble_sort(arr):
186
- n = len(arr)
187
- for i in range(n):
188
- swapped = False
189
- for j in range(0, n - i - 1):
190
- if arr[j] > arr[j + 1]:
191
- arr[j], arr[j + 1] = arr[j + 1], arr[j]
192
- swapped = True
193
- if not swapped:
194
- break
195
- return arr
196
-
197
- numbers = [64, 34, 25, 12, 22, 11, 90]
198
- print("Original:", numbers)
199
- print("Sorted:", bubble_sort(numbers.copy()))
200
- <ENDCODE>
201
-
202
- Output:
203
- Original: [64, 34, 25, 12, 22, 11, 90]
204
- Sorted: [11, 12, 22, 25, 34, 64, 90]
205
-
206
- How it works: repeatedly compares adjacent elements and swaps them if in wrong order. Largest elements bubble up to the end.
207
-
208
- Time complexity: O(n^2)
209
-
210
- <USER> Write binary search
211
- <ASSISTANT> Here is the Binary Search algorithm:
212
-
213
- <CODE>
214
- def binary_search(arr, target):
215
- left = 0
216
- right = len(arr) - 1
217
-
218
- while left <= right:
219
- mid = (left + right) // 2
220
-
221
- if arr[mid] == target:
222
- return mid
223
- elif arr[mid] < target:
224
- left = mid + 1
225
- else:
226
- right = mid - 1
227
-
228
- return -1
229
-
230
- numbers = [1, 3, 5, 7, 9, 11, 13, 15]
231
- print(binary_search(numbers, 7))
232
- print(binary_search(numbers, 6))
233
- <ENDCODE>
234
-
235
- Output:
236
- 3
237
- -1
238
-
239
- Important: The array must be sorted!
240
- Time complexity: O(log n)
241
-
242
- <USER> What is a class?
243
- <ASSISTANT> A class is a blueprint for creating objects. It bundles data (attributes) and functions (methods) together.
244
-
245
- <CODE>
246
- class Dog:
247
- def __init__(self, name, age):
248
- self.name = name
249
- self.age = age
250
-
251
- def bark(self):
252
- return f"{self.name} says Woof!"
253
-
254
- def get_info(self):
255
- return f"{self.name} is {self.age} years old"
256
-
257
- my_dog = Dog("Buddy", 3)
258
- print(my_dog.bark())
259
- print(my_dog.get_info())
260
- <ENDCODE>
261
-
262
- Output:
263
- Buddy says Woof!
264
- Buddy is 3 years old
265
-
266
- Key concepts:
267
- - __init__: constructor, runs when object is created
268
- - self: refers to the current object
269
- - Attributes: variables inside the class
270
- - Methods: functions inside the class
271
-
272
- <USER> How do I read a file?
273
- <ASSISTANT> Here is how to read files in Python:
274
-
275
- Read entire file:
276
-
277
- <CODE>
278
- with open("filename.txt", "r") as file:
279
- content = file.read()
280
- print(content)
281
- <ENDCODE>
282
-
283
- Read line by line:
284
-
285
- <CODE>
286
- with open("filename.txt", "r") as file:
287
- for line in file:
288
- print(line.strip())
289
- <ENDCODE>
290
-
291
- Read all lines into list:
292
-
293
- <CODE>
294
- with open("filename.txt", "r") as file:
295
- lines = file.readlines()
296
- <ENDCODE>
297
-
298
- With error handling:
299
-
300
- <CODE>
301
- def read_file_safely(filename):
302
- try:
303
- with open(filename, "r") as file:
304
- return file.read()
305
- except FileNotFoundError:
306
- print(f"File not found: {filename}")
307
- return None
308
-
309
- content = read_file_safely("myfile.txt")
310
- <ENDCODE>
311
-
312
- The with statement automatically closes the file when done.
313
-
314
- <USER> How do I write to a file?
315
- <ASSISTANT> Here is how to write to files in Python:
316
-
317
- Write (overwrites existing content):
318
-
319
- <CODE>
320
- with open("output.txt", "w") as file:
321
- file.write("Hello, World!")
322
- file.write("This is a new line")
323
- <ENDCODE>
324
-
325
- Append (adds to existing content):
326
-
327
- <CODE>
328
- with open("output.txt", "a") as file:
329
- file.write("This is appended")
330
- <ENDCODE>
331
-
332
- Write multiple lines:
333
-
334
- <CODE>
335
- lines = ["Line 1", "Line 2", "Line 3"]
336
- with open("output.txt", "w") as file:
337
- for line in lines:
338
- file.write(line + "\\n")
339
- <ENDCODE>
340
-
341
- Or use writelines:
342
-
343
- <CODE>
344
- lines = ["Line 1\\n", "Line 2\\n", "Line 3\\n"]
345
- with open("output.txt", "w") as file:
346
- file.writelines(lines)
347
- <ENDCODE>
348
-
349
- <USER> How do I handle errors?
350
- <ASSISTANT> Python uses try-except blocks for error handling:
351
-
352
- Basic try-except:
353
-
354
- <CODE>
355
- try:
356
- result = 10 / 0
357
- except ZeroDivisionError:
358
- print("Cannot divide by zero!")
359
- <ENDCODE>
360
-
361
- Multiple exceptions:
362
-
363
- <CODE>
364
- try:
365
- number = int(input("Enter number: "))
366
- result = 10 / number
367
- except ValueError:
368
- print("Not a valid number!")
369
- except ZeroDivisionError:
370
- print("Cannot divide by zero!")
371
- except Exception as e:
372
- print(f"Error: {e}")
373
- <ENDCODE>
374
-
375
- Complete structure:
376
-
377
- <CODE>
378
- try:
379
- result = risky_operation()
380
- except SomeError:
381
- print("Handle error")
382
- else:
383
- print("Success! No error occurred")
384
- finally:
385
- print("This always runs")
386
- <ENDCODE>
387
-
388
- Common exceptions:
389
- - ValueError: wrong value
390
- - TypeError: wrong type
391
- - IndexError: index out of range
392
- - KeyError: dictionary key not found
393
- - FileNotFoundError: file does not exist
394
- - ZeroDivisionError: division by zero
395
-
396
- <USER> What is a dictionary?
397
- <ASSISTANT> A dictionary is a collection of key-value pairs in Python.
398
-
399
- Creating dictionaries:
400
-
401
- <CODE>
402
- empty_dict = {}
403
- person = {"name": "Alice", "age": 25, "city": "NYC"}
404
- another = dict(name="Bob", age=30)
405
- <ENDCODE>
406
-
407
- Accessing values:
408
-
409
- <CODE>
410
- person = {"name": "Alice", "age": 25}
411
-
412
- print(person["name"])
413
- print(person.get("name"))
414
- print(person.get("job", "Unknown"))
415
- <ENDCODE>
416
-
417
- Modifying:
418
-
419
- <CODE>
420
- person = {"name": "Alice", "age": 25}
421
-
422
- person["city"] = "Boston"
423
- person["age"] = 26
424
- del person["city"]
425
- age = person.pop("age")
426
- <ENDCODE>
427
-
428
- Iterating:
429
-
430
- <CODE>
431
- person = {"name": "Alice", "age": 25}
432
-
433
- for key in person:
434
- print(key)
435
-
436
- for value in person.values():
437
- print(value)
438
-
439
- for key, value in person.items():
440
- print(f"{key}: {value}")
441
- <ENDCODE>
442
-
443
- <USER> What is the difference between list and tuple?
444
- <ASSISTANT> The main difference is mutability:
445
- - Lists are mutable (can be changed)
446
- - Tuples are immutable (cannot be changed)
447
-
448
- Syntax:
449
-
450
- <CODE>
451
- my_list = [1, 2, 3]
452
- my_tuple = (1, 2, 3)
453
- <ENDCODE>
454
-
455
- Mutability difference:
456
-
457
- <CODE>
458
- my_list = [1, 2, 3]
459
- my_list[0] = 100
460
- my_list.append(4)
461
- print(my_list)
462
-
463
- my_tuple = (1, 2, 3)
464
- # my_tuple[0] = 100 # ERROR! Cannot modify
465
- # my_tuple.append(4) # ERROR! No append method
466
- <ENDCODE>
467
-
468
- When to use:
469
- - List: data that will change (shopping cart, scores)
470
- - Tuple: fixed data (coordinates, RGB colors)
471
-
472
- Examples:
473
-
474
- <CODE>
475
- coordinates = (10, 20)
476
- rgb_color = (255, 128, 0)
477
-
478
- shopping_cart = ["apple", "bread"]
479
- shopping_cart.append("milk")
480
- <ENDCODE>
481
-
482
- Other differences:
483
- - Tuples are slightly faster
484
- - Tuples use less memory
485
- - Tuples can be dictionary keys
486
-
487
- <USER> How do I use list comprehension?
488
- <ASSISTANT> List comprehension is a concise way to create lists.
489
-
490
- Basic syntax:
491
-
492
- <CODE>
493
- new_list = [expression for item in iterable]
494
- <ENDCODE>
495
-
496
- Examples:
497
-
498
- <CODE>
499
- squares = [x ** 2 for x in range(5)]
500
- print(squares)
501
- <ENDCODE>
502
-
503
- Output: [0, 1, 4, 9, 16]
504
-
505
- With condition:
506
-
507
- <CODE>
508
- evens = [x for x in range(10) if x % 2 == 0]
509
- print(evens)
510
- <ENDCODE>
511
-
512
- Output: [0, 2, 4, 6, 8]
513
-
514
- Transform strings:
515
-
516
- <CODE>
517
- words = ["hello", "world"]
518
- upper_words = [word.upper() for word in words]
519
- print(upper_words)
520
- <ENDCODE>
521
-
522
- Output: ["HELLO", "WORLD"]
523
-
524
- With if-else:
525
-
526
- <CODE>
527
- numbers = [1, 2, 3, 4, 5]
528
- labels = ["even" if x % 2 == 0 else "odd" for x in numbers]
529
- print(labels)
530
- <ENDCODE>
531
-
532
- Output: ["odd", "even", "odd", "even", "odd"]
533
-
534
- <USER> Write quicksort
535
- <ASSISTANT> Here is the QuickSort algorithm:
536
-
537
- <CODE>
538
- def quicksort(arr):
539
- if len(arr) <= 1:
540
- return arr
541
-
542
- pivot = arr[len(arr) // 2]
543
- left = [x for x in arr if x < pivot]
544
- middle = [x for x in arr if x == pivot]
545
- right = [x for x in arr if x > pivot]
546
-
547
- return quicksort(left) + middle + quicksort(right)
548
-
549
- numbers = [64, 34, 25, 12, 22, 11, 90]
550
- print("Original:", numbers)
551
- print("Sorted:", quicksort(numbers))
552
- <ENDCODE>
553
-
554
- Output:
555
- Original: [64, 34, 25, 12, 22, 11, 90]
556
- Sorted: [11, 12, 22, 25, 34, 64, 90]
557
-
558
- In-place version:
559
-
560
- <CODE>
561
- def quicksort_inplace(arr, low, high):
562
- if low < high:
563
- pivot_idx = partition(arr, low, high)
564
- quicksort_inplace(arr, low, pivot_idx - 1)
565
- quicksort_inplace(arr, pivot_idx + 1, high)
566
-
567
- def partition(arr, low, high):
568
- pivot = arr[high]
569
- i = low - 1
570
- for j in range(low, high):
571
- if arr[j] <= pivot:
572
- i += 1
573
- arr[i], arr[j] = arr[j], arr[i]
574
- arr[i + 1], arr[high] = arr[high], arr[i + 1]
575
- return i + 1
576
-
577
- numbers = [64, 34, 25, 12, 22, 11, 90]
578
- quicksort_inplace(numbers, 0, len(numbers) - 1)
579
- print(numbers)
580
- <ENDCODE>
581
-
582
- Time complexity:
583
- - Average: O(n log n)
584
- - Worst: O(n^2)
585
-
586
- <USER> Write merge sort
587
- <ASSISTANT> Here is the Merge Sort algorithm:
588
-
589
- <CODE>
590
- def merge_sort(arr):
591
- if len(arr) <= 1:
592
- return arr
593
-
594
- mid = len(arr) // 2
595
- left = merge_sort(arr[:mid])
596
- right = merge_sort(arr[mid:])
597
-
598
- return merge(left, right)
599
-
600
- def merge(left, right):
601
- result = []
602
- i = j = 0
603
-
604
- while i < len(left) and j < len(right):
605
- if left[i] <= right[j]:
606
- result.append(left[i])
607
- i += 1
608
- else:
609
- result.append(right[j])
610
- j += 1
611
-
612
- result.extend(left[i:])
613
- result.extend(right[j:])
614
- return result
615
-
616
- numbers = [64, 34, 25, 12, 22, 11, 90]
617
- print("Original:", numbers)
618
- print("Sorted:", merge_sort(numbers))
619
- <ENDCODE>
620
-
621
- Output:
622
- Original: [64, 34, 25, 12, 22, 11, 90]
623
- Sorted: [11, 12, 22, 25, 34, 64, 90]
624
-
625
- How it works:
626
- 1. Divide array into two halves
627
- 2. Recursively sort each half
628
- 3. Merge the sorted halves
629
-
630
- Time complexity: O(n log n) always
631
- Space complexity: O(n)
632
-
633
- <USER> What is a lambda function?
634
- <ASSISTANT> A lambda function is a small anonymous function defined in one line.
635
-
636
- Syntax:
637
-
638
- <CODE>
639
- lambda arguments: expression
640
- <ENDCODE>
641
-
642
- Examples:
643
-
644
- <CODE>
645
- add = lambda x, y: x + y
646
- print(add(5, 3))
647
-
648
- square = lambda x: x ** 2
649
- print(square(4))
650
-
651
- is_even = lambda x: x % 2 == 0
652
- print(is_even(4))
653
- <ENDCODE>
654
-
655
- Output:
656
- 8
657
- 16
658
- True
659
-
660
- Common uses:
661
-
662
- With sorted():
663
-
664
- <CODE>
665
- students = [("Alice", 85), ("Bob", 92), ("Charlie", 78)]
666
- sorted_by_score = sorted(students, key=lambda x: x[1])
667
- print(sorted_by_score)
668
- <ENDCODE>
669
-
670
- With map():
671
-
672
- <CODE>
673
- numbers = [1, 2, 3, 4, 5]
674
- squares = list(map(lambda x: x ** 2, numbers))
675
- print(squares)
676
- <ENDCODE>
677
-
678
- With filter():
679
-
680
- <CODE>
681
- numbers = [1, 2, 3, 4, 5, 6]
682
- evens = list(filter(lambda x: x % 2 == 0, numbers))
683
- print(evens)
684
- <ENDCODE>
685
-
686
- <USER> What are decorators?
687
- <ASSISTANT> Decorators are functions that modify the behavior of other functions.
688
-
689
- Basic decorator:
690
-
691
- <CODE>
692
- def my_decorator(func):
693
- def wrapper():
694
- print("Before function")
695
- func()
696
- print("After function")
697
- return wrapper
698
-
699
- @my_decorator
700
- def say_hello():
701
- print("Hello!")
702
-
703
- say_hello()
704
- <ENDCODE>
705
 
706
- Output:
707
- Before function
708
- Hello!
709
- After function
710
 
711
- With arguments:
 
712
 
713
- <CODE>
714
- def my_decorator(func):
715
- def wrapper(*args, **kwargs):
716
- print("Before function")
717
- result = func(*args, **kwargs)
718
- print("After function")
719
- return result
720
- return wrapper
721
 
722
- @my_decorator
 
723
  def add(a, b):
724
  return a + b
 
725
 
726
- print(add(5, 3))
727
- <ENDCODE>
728
-
729
- Practical example - timer:
730
-
731
- <CODE>
732
- import time
733
-
734
- def timer(func):
735
- def wrapper(*args, **kwargs):
736
- start = time.time()
737
- result = func(*args, **kwargs)
738
- end = time.time()
739
- print(f"{func.__name__} took {end - start:.4f} seconds")
740
- return result
741
- return wrapper
742
-
743
- @timer
744
- def slow_function():
745
- time.sleep(1)
746
- return "Done"
747
-
748
- slow_function()
749
- <ENDCODE>
750
-
751
- <USER> Thanks!
752
- <ASSISTANT> You are welcome! I am glad I could help. Feel free to ask me anything anytime. Happy coding!
753
-
754
- <USER> Hi
755
- <ASSISTANT> Hello! How can I help you today? I can write code, explain concepts, or answer your programming questions!
756
-
757
- <USER> What can you do?
758
- <ASSISTANT> I am Veda Programming Assistant! I can help you with:
759
-
760
- 1. Writing Python code
761
- 2. Explaining programming concepts
762
- 3. Debugging code
763
- 4. Answering questions about algorithms
764
- 5. Teaching data structures
765
- 6. Helping with file handling
766
- 7. Explaining object-oriented programming
767
- 8. And much more!
768
-
769
- Just ask me anything about programming and I will do my best to help you!
770
-
771
- <USER> 2+2=?
772
- <ASSISTANT> 4
773
-
774
- <USER> What is 10 times 5?
775
- <ASSISTANT> 10 times 5 equals 50.
776
-
777
- <USER> Calculate 100 divided by 4
778
- <ASSISTANT> 100 divided by 4 equals 25.
779
  """
780
 
781
-
782
  class VedaTrainer:
783
- """Trainer for Veda Programming Assistant with Distillation Support"""
784
-
785
- def __init__(self, vocab_size: int = 8000, max_length: int = 512, batch_size: int = 4):
786
- self.vocab_size = vocab_size
787
- self.max_length = max_length
788
- self.batch_size = batch_size
789
- self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
790
  self.model = None
791
 
792
- def prepare_data(self, extra_data: str = "", distillation_data: str = ""):
793
- """Prepare training data including distillation data"""
794
- data = TRAINING_DATA
795
-
796
- if extra_data:
797
- data += "\n\n" + extra_data
798
-
799
- if distillation_data:
800
- data += "\n\n" + distillation_data
801
-
802
- if os.path.exists("programming.txt"):
803
- try:
804
- with open("programming.txt", "r", encoding="utf-8") as f:
805
- code_data = f.read()
806
- data += "\n\n" + code_data
807
- except Exception as e:
808
- print(f"Warning: Could not read programming.txt: {e}")
809
-
810
  self.tokenizer.fit([data])
811
-
812
- all_tokens = self.tokenizer.encode(data)
813
- print(f"Total tokens: {len(all_tokens)}")
814
-
815
- sequences = []
816
- stride = self.max_length // 2
817
-
818
- for i in range(0, len(all_tokens) - self.max_length - 1, stride):
819
- seq = all_tokens[i : i + self.max_length + 1]
820
- if len(seq) == self.max_length + 1:
821
- sequences.append(seq)
822
-
823
- if len(sequences) < 10:
824
- stride = self.max_length // 4
825
- sequences = []
826
- for i in range(0, len(all_tokens) - self.max_length - 1, stride):
827
- seq = all_tokens[i : i + self.max_length + 1]
828
- if len(seq) == self.max_length + 1:
829
- sequences.append(seq)
830
-
831
- print(f"Created {len(sequences)} training sequences")
832
-
833
- if len(sequences) == 0:
834
- print("Warning: No sequences created. Using minimal sequence.")
835
- min_seq = all_tokens[:self.max_length + 1]
836
- while len(min_seq) < self.max_length + 1:
837
- min_seq.append(0)
838
- sequences = [min_seq]
839
-
840
- sequences = np.array(sequences)
841
- X = sequences[:, :-1]
842
- y = sequences[:, 1:]
843
-
844
- dataset = tf.data.Dataset.from_tensor_slices((X, y))
845
- dataset = dataset.shuffle(1000).batch(self.batch_size).prefetch(1)
846
-
847
- return dataset
848
-
849
- def build_model(self):
850
- """Build the model"""
851
- self.model = VedaProgrammingLLM(
852
- vocab_size=self.tokenizer.vocabulary_size,
853
- max_length=self.max_length,
854
- d_model=256,
855
- num_heads=8,
856
- num_layers=4,
857
- ff_dim=512,
858
- )
859
-
860
- self.model.compile(
861
- optimizer=keras.optimizers.Adam(learning_rate=1e-4),
862
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
863
- metrics=["accuracy"],
864
- )
865
-
866
- dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
867
- self.model(dummy)
868
-
869
- return self.model
870
-
871
- def train(
872
- self,
873
- epochs: int = 15,
874
- save_path: str = None,
875
- extra_data: str = "",
876
- distillation_data: str = "",
877
- ):
878
- """Train the model"""
879
- if save_path is None:
880
- save_path = MODEL_DIR
881
-
882
- dataset = self.prepare_data(extra_data, distillation_data)
883
- self.build_model()
884
-
885
- self.model.summary()
886
-
887
- os.makedirs(save_path, exist_ok=True)
888
-
889
- history = self.model.fit(dataset, epochs=epochs, verbose=1)
890
-
891
- # Save weights
892
- self.model.save_weights(os.path.join(save_path, "weights.h5"))
893
 
894
- # Save tokenizer
895
- self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
896
-
897
- # Save config
898
- config = self.model.get_config()
899
- with open(os.path.join(save_path, "config.json"), "w") as f:
900
- json.dump(config, f, indent=2)
901
-
902
- print(f"Model saved to {save_path}")
903
- return history
904
-
905
- def generate_response(
906
- self, user_input: str, max_tokens: int = 200, temperature: float = 0.7
907
- ) -> str:
908
- """Generate a response"""
909
- if self.model is None:
910
- return "Model not loaded."
911
-
912
- prompt = f"<USER> {user_input}\n<ASSISTANT>"
913
-
914
- tokens = self.tokenizer.encode(prompt)
915
-
916
- generated = self.model.generate(
917
- tokens,
918
- max_new_tokens=max_tokens,
919
- temperature=temperature,
920
- repetition_penalty=1.2,
921
- )
922
-
923
- response = self.tokenizer.decode(generated)
924
-
925
- if "<ASSISTANT>" in response:
926
- response = response.split("<ASSISTANT>")[-1].strip()
927
- if "<USER>" in response:
928
- response = response.split("<USER>")[0].strip()
929
-
930
- return response
931
 
 
 
 
 
 
 
 
 
 
 
 
 
932
 
933
  if __name__ == "__main__":
934
- print("=" * 50)
935
- print("Training Veda Programming Assistant")
936
- print("=" * 50)
937
-
938
- trainer = VedaTrainer()
939
- trainer.train(epochs=20)
940
-
941
- print("\n" + "=" * 50)
942
- print("Testing the model:")
943
- print("=" * 50)
944
-
945
- test_prompts = [
946
- "Hello!",
947
- "What is a function?",
948
- "Write a function to reverse a string",
949
- "2+2=?",
950
- ]
951
-
952
- for prompt in test_prompts:
953
- print(f"\nUser: {prompt}")
954
- response = trainer.generate_response(prompt)
955
- print(f"Assistant: {response}")
 
 
 
1
  import tensorflow as tf
 
 
 
 
 
2
  from model import VedaProgrammingLLM
3
  from tokenizer import VedaTokenizer
4
+ from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH
5
+ import os
6
+ import json
7
 
 
8
  TRAINING_DATA = """
9
+ <USER> Hi
10
+ <ASSISTANT> Hello! How can I help you today?
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ <USER> Who are you?
13
+ <ASSISTANT> I am Veda, an AI programming assistant.
 
 
14
 
15
+ <USER> What is Python?
16
+ <ASSISTANT> Python is a high-level programming language known for its simplicity and readability.
17
 
18
+ <USER> Write hello world in python
19
+ <ASSISTANT> <CODE>print("Hello, World!")<ENDCODE>
 
 
 
 
 
 
20
 
21
+ <USER> Write a function to add numbers
22
+ <ASSISTANT> <CODE>
23
  def add(a, b):
24
  return a + b
25
+ </CODE>
26
 
27
+ <USER> Explain recursion
28
+ <ASSISTANT> Recursion is when a function calls itself to solve a problem. It needs a base case to stop.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
 
 
31
  class VedaTrainer:
32
+ def __init__(self):
33
+ self.tokenizer = VedaTokenizer(VOCAB_SIZE)
 
 
 
 
 
34
  self.model = None
35
 
36
+ def train(self, epochs=10, extra_data=""):
37
+ data = TRAINING_DATA + "\n" + extra_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  self.tokenizer.fit([data])
39
+ tokens = self.tokenizer.encode(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Create dataset
42
+ seqs = []
43
+ for i in range(0, len(tokens)-MAX_LENGTH, 50):
44
+ seqs.append(tokens[i:i+MAX_LENGTH+1])
45
+
46
+ import numpy as np
47
+ if not seqs: seqs = [tokens[:MAX_LENGTH+1]]
48
+ arr = np.array(seqs)
49
+ ds = tf.data.Dataset.from_tensor_slices((arr[:, :-1], arr[:, 1:])).batch(4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ self.model = VedaProgrammingLLM(self.tokenizer.vocabulary_size)
52
+ self.model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
53
+
54
+ # Build model
55
+ self.model(tf.zeros((1, MAX_LENGTH)))
56
+ self.model.fit(ds, epochs=epochs)
57
+
58
+ # Save
59
+ self.model.save_weights(os.path.join(MODEL_DIR, "weights.h5"))
60
+ self.tokenizer.save(os.path.join(MODEL_DIR, "tokenizer.json"))
61
+ with open(os.path.join(MODEL_DIR, "config.json"), 'w') as f:
62
+ json.dump(self.model.get_config(), f)
63
 
64
  if __name__ == "__main__":
65
+ VedaTrainer().train(epochs=20)