#!/usr/bin/env python3
script_name = "brokenhill.py"
script_version = "0.37"
script_date = "2024-11-15"
def get_logo():
result = " \n"
#result = "xxxxxxxxxxxxxxxxxxx|xxxxxxxxxxxxxxxxxxx||xxxxxxxxxxxxxxxxxxx|xxxxxxxxxxxxxxxxxxx\n"
result += " . . \n"
result += " .oO___________ ___________Oo. \n"
result += " . \____________________________________________/ . \n"
result += " | | \n"
result += " | | \n"
result += " | Broken Hill | \n"
result += " | | \n"
result += " | a tool for attacking LLMs, presented by Bishop Fox | \n"
result += " | | \n"
result += " | https://github.com/BishopFox/BrokenHill | \n"
result += " | | \n"
result += " | ________________________________ | \n"
result += " ' _________________/ \_________________ ' \n"
result += " '^O O^' \n"
result += " ' ' \n"
result += " \n"
return result
# def get_logo():
# result = " \n"
# #result = "xxxxxxxxxxxxxxxxxxx|xxxxxxxxxxxxxxxxxxx||xxxxxxxxxxxxxxxxxxx|xxxxxxxxxxxxxxxxxxx\n"
# result += " . . \n"
# result += " .oO_________________ _________________Oo. \n"
# result += " . \________________________________/ . \n"
# result += " | | \n"
# result += " | | \n"
# result += " | Broken Hill | \n"
# result += " | | \n"
# result += " | a tool for attacking LLMs, presented by Bishop Fox | \n"
# result += " | | \n"
# result += " | https://github.com/BishopFox/BrokenHill | \n"
# result += " | | \n"
# result += " | ____________________________________________ | \n"
# result += " ' ___________/ \___________ ' \n"
# result += " '^O O^' \n"
# result += " ' ' \n"
# result += " \n"
# return result
def get_script_description():
result = 'Performs a "Greedy Coordinate Gradient" (GCG) attack against various large language models (LLMs), as described in the paper "Universal and Transferable Adversarial Attacks on Aligned Language Models" by Andy Zou, Zifan Wang, Nicholas Carlini, Milad Nasr, J. Zico Kolter, and Matt Fredrikson, representing Carnegie Mellon University, the Center for AI Safety, Google DeepMind, and the Bosch Center for AI.'
result += "\n"
result += "Originally based on the demo.ipynb notebook and associated llm-attacks library from https://github.com/llm-attacks/llm-attacks"
result += "\n"
result += "Also incorporates gradient-sampling code and mellowmax loss function from nanoGCG - https://github.com/GraySwanAI/nanoGCG"
result += "\n"
result += "This tool created and all other post-fork changes to the associated library by Ben Lincoln, Bishop Fox."
result += "\n"
result += f"version {script_version}, {script_date}"
return result
def get_short_script_description():
result = 'Based on code and research by Andy Zou, Zifan Wang, Nicholas Carlini, Milad Nasr, J. Zico Kolter, and Matt Fredrikson.'
result += "\n"
result += "Also incorporates gradient-sampling code and mellowmax loss function from nanoGCG - https://github.com/GraySwanAI/nanoGCG"
result += "\n"
result += "This tool created and all other post-fork changes to the associated library by Ben Lincoln, Bishop Fox."
return result
import argparse
import base64
import copy
import datetime
# IMPORTANT: 'fastchat' is in the PyPI package 'fschat', not 'fastchat'!
#import fastchat.conversation
import fastchat as fschat
import fastchat.conversation as fschat_conversation
import gc
import locale
import json
import logging
import math
import numpy
import os
import pathlib
import psutil
import re
import shutil
import sys
import tempfile
import time
import torch
import torch.nn as nn
import torch.quantization as tq
import traceback
from llm_attacks_bishopfox.attack.attack_classes import AdversarialContent
from llm_attacks_bishopfox.attack.attack_classes import AdversarialContentList
from llm_attacks_bishopfox.attack.attack_classes import AdversarialContentPlacement
from llm_attacks_bishopfox.attack.attack_classes import AssociationRebuildException
from llm_attacks_bishopfox.attack.attack_classes import AttackInitializationException
from llm_attacks_bishopfox.attack.attack_classes import AttackParams
from llm_attacks_bishopfox.attack.attack_classes import AttackResultInfo
from llm_attacks_bishopfox.attack.attack_classes import AttackResultInfoCollection
from llm_attacks_bishopfox.attack.attack_classes import AttackResultInfoData
from llm_attacks_bishopfox.attack.attack_classes import BrokenHillMode
from llm_attacks_bishopfox.attack.attack_classes import BrokenHillRandomNumberGenerators
from llm_attacks_bishopfox.attack.attack_classes import BrokenHillResultData
from llm_attacks_bishopfox.attack.attack_classes import DecodingException
from llm_attacks_bishopfox.attack.attack_classes import EncodingException
from llm_attacks_bishopfox.attack.attack_classes import GenerationException
from llm_attacks_bishopfox.attack.attack_classes import InitialAdversarialContentCreationMode
from llm_attacks_bishopfox.attack.attack_classes import LossAlgorithm
from llm_attacks_bishopfox.attack.attack_classes import LossSliceMode
from llm_attacks_bishopfox.attack.attack_classes import LossThresholdException
from llm_attacks_bishopfox.attack.attack_classes import ModelDataFormatHandling
from llm_attacks_bishopfox.attack.attack_classes import MyCurrentMentalImageOfALargeValueShouldBeEnoughForAnyoneException
from llm_attacks_bishopfox.attack.attack_classes import OverallScoringFunction
from llm_attacks_bishopfox.attack.attack_classes import PersistableAttackState
from llm_attacks_bishopfox.attack.attack_classes import VolatileAttackState
from llm_attacks_bishopfox.attack.attack_classes import get_missing_pad_token_names
from llm_attacks_bishopfox.base.attack_manager import EmbeddingLayerNotFoundException
from llm_attacks_bishopfox.dumpster_fires.conversation_templates import SeparatorStyleConversionException
from llm_attacks_bishopfox.dumpster_fires.conversation_templates import ConversationTemplateSerializationException
from llm_attacks_bishopfox.dumpster_fires.trash_fire_tokens import TrashFireTokenException
from llm_attacks_bishopfox.dumpster_fires.trash_fire_tokens import get_decoded_token
from llm_attacks_bishopfox.dumpster_fires.trash_fire_tokens import get_decoded_tokens
from llm_attacks_bishopfox.jailbreak_detection.jailbreak_detection import LLMJailbreakDetector
from llm_attacks_bishopfox.jailbreak_detection.jailbreak_detection import LLMJailbreakDetectorRuleSet
from llm_attacks_bishopfox.json_serializable_object import JSONSerializationException
from llm_attacks_bishopfox.llms.large_language_models import LargeLanguageModelException
from llm_attacks_bishopfox.llms.large_language_models import LargeLanguageModelParameterException
from llm_attacks_bishopfox.logging import BrokenHillLogManager
from llm_attacks_bishopfox.logging import LoggingException
from llm_attacks_bishopfox.minimal_gcg.adversarial_content_utils import AdversarialContentManager
from llm_attacks_bishopfox.minimal_gcg.adversarial_content_utils import PromptGenerationException
from llm_attacks_bishopfox.minimal_gcg.adversarial_content_utils import register_custom_conversation_templates
from llm_attacks_bishopfox.minimal_gcg.opt_utils import GradientCreationException
from llm_attacks_bishopfox.minimal_gcg.opt_utils import GradientSamplingException
from llm_attacks_bishopfox.minimal_gcg.opt_utils import MellowmaxException
from llm_attacks_bishopfox.minimal_gcg.opt_utils import NullPaddingTokenException
from llm_attacks_bishopfox.minimal_gcg.opt_utils import PaddingException
from llm_attacks_bishopfox.minimal_gcg.opt_utils import get_adversarial_content_candidates
from llm_attacks_bishopfox.minimal_gcg.opt_utils import get_filtered_cands
from llm_attacks_bishopfox.minimal_gcg.opt_utils import get_logits
from llm_attacks_bishopfox.minimal_gcg.opt_utils import target_loss
from llm_attacks_bishopfox.minimal_gcg.opt_utils import token_gradients
from llm_attacks_bishopfox.statistics.statistical_tools import StatisticsException
from llm_attacks_bishopfox.teratogenic_tokens.language_names import HumanLanguageException
from llm_attacks_bishopfox.util.util_functions import BrokenHillFileIOException
from llm_attacks_bishopfox.util.util_functions import BrokenHillValueException
from llm_attacks_bishopfox.util.util_functions import FakeException
from llm_attacks_bishopfox.util.util_functions import PyTorchDevice
from llm_attacks_bishopfox.util.util_functions import comma_delimited_string_to_integer_array
from llm_attacks_bishopfox.util.util_functions import command_array_to_string
from llm_attacks_bishopfox.util.util_functions import delete_file
from llm_attacks_bishopfox.util.util_functions import get_broken_hill_state_file_name
from llm_attacks_bishopfox.util.util_functions import get_elapsed_time_string
from llm_attacks_bishopfox.util.util_functions import get_escaped_string
from llm_attacks_bishopfox.util.util_functions import get_file_content
from llm_attacks_bishopfox.util.util_functions import get_file_content_from_sys_argv
from llm_attacks_bishopfox.util.util_functions import get_log_level_names
from llm_attacks_bishopfox.util.util_functions import get_now
from llm_attacks_bishopfox.util.util_functions import get_random_token_id
from llm_attacks_bishopfox.util.util_functions import get_random_token_ids
from llm_attacks_bishopfox.util.util_functions import get_time_string
from llm_attacks_bishopfox.util.util_functions import load_json_from_file
from llm_attacks_bishopfox.util.util_functions import log_level_name_to_log_level
from llm_attacks_bishopfox.util.util_functions import numeric_string_to_float
from llm_attacks_bishopfox.util.util_functions import numeric_string_to_int
from llm_attacks_bishopfox.util.util_functions import safely_write_text_output_file
from llm_attacks_bishopfox.util.util_functions import str2bool
from llm_attacks_bishopfox.util.util_functions import update_elapsed_time_string
from llm_attacks_bishopfox.util.util_functions import verify_output_file_capability
from peft import PeftModel
from torch.quantization import quantize_dynamic
from torch.quantization.qconfig import float_qparams_weight_only_qconfig
logger = logging.getLogger(__name__)
SAFETENSORS_WEIGHTS_FILE_NAME = "adapter_model.safetensors"
# threshold for warning the user if the specified PyTorch device already has more than this percent of its memory reserved
# 0.1 = 10%
torch_device_reserved_memory_warning_threshold = 0.1
# Use the OS-level locale
locale.setlocale(locale.LC_ALL, '')
# Workaround for glitchy Protobuf code somewhere
# See https://stackoverflow.com/questions/75042153/cant-load-from-autotokenizer-from-pretrained-typeerror-duplicate-file-name
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]="python"
# Workaround for overly-chatty-by-default PyTorch code
# all_loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# for existing_logger in all_loggers:
# existing_logger.setLevel(logging.WARNING)
def check_pytorch_devices(attack_params):
all_devices = {}
devices_above_threshold = {}
if attack_params.using_cuda():
cuda_devices = PyTorchDevice.get_all_cuda_devices()
for i in range(0, len(cuda_devices)):
d = cuda_devices[i]
all_devices[d.device_name] = d
if d.total_memory_utilization > torch_device_reserved_memory_warning_threshold:
devices_above_threshold[d.device_name] = d
device_names = list(all_devices.keys())
if len(device_names) > 0:
device_names.sort()
message = f"Available PyTorch devices for the back-end in use:\n"
for dn in device_names:
d = all_devices[dn]
message += f"\t{d.device_name} - {d.device_display_name}\n"
message += f"\t\tTotal memory: {d.total_memory:n} byte(s)\n"
message += f"\t\tMemory in use across the entire device: {d.gpu_used_memory:n} byte(s)\n"
message += f"\t\tCurrent memory utilization for the device as a whole: {d.total_memory_utilization:.0%}\n"
logger.info(message)
above_threshold_device_names = list(devices_above_threshold.keys())
if len(above_threshold_device_names) > 0:
above_threshold_device_names.sort()
warning_message = f"The following PyTorch devices have more than {torch_device_reserved_memory_warning_threshold:.0%} of their memory reserved:\n"
for dn in above_threshold_device_names:
d = devices_above_threshold[dn]
warning_message += f"\t{d.device_name} ({d.device_display_name}): {d.total_memory_utilization:.0%}\n"
warning_message += f"If you encounter out-of-memory errors when using Broken Hill, consider suspending other processes that use GPU resources, to maximize the amount of memory available to PyTorch. For example, on Linux desktops with a GUI enabled, consider switching to a text-only console to suspend the display manager and free up the associated VRAM. On Debian, Ctrl-Alt-F2 switches to the second console, and Ctrl-Alt-F1 switches back to the default console.\n"
logger.warning(warning_message)
def main(attack_params, log_manager):
user_aborted = False
abnormal_termination = False
general_error_guidance_message = "\n\nIf this issue is due to a temporary or operator-correctable condition (such as another process using device memory that would otherwise be available to Broken Hill) you may be able to continue execution (after correcting any issues) using the instructions that Broken Hill should generate before exiting.\n"
bug_guidance_message = "\n\nIf this error occurred while using Broken Hill with an LLM in the list of officially tested models, using the recommended options for that model, is not discussed in the troubleshooting documentation, and is not the result of an operator-correctable condition (such as an invalid path, insufficient memory, etc.), please open an issue with the Broken Hill developers, including steps to reproduce the error.\nIf this error occurred while using Broken Hill with an LLM that is not in the list of tested models, you may submit a feature request to add support for the model.\nIf this error occurred while using Broken Hill with an LLM in the list of officially tested models, but *not* using the recommended options for that model, please try using the recommended options instead."
attack_state = VolatileAttackState()
attack_state.log_manager = log_manager
if attack_params.load_state_from_file:
state_file_dict = None
try:
state_file_dict = load_json_from_file(attack_params.load_state_from_file)
except Exception as e:
logger.critical(f"Could not load state JSON data from '{attack_params.load_state_from_file}': {e}\n{traceback.format_exc()}")
sys.exit(1)
# Loading the AttackParams directly from the saved state won't work, because then the user wouldn't be able to override them with other explicit command-line options.
# attack_params is already a merged version of whatever combination of sources was specified on the command line.
#merged_attack_params = attack_params.copy()
# Everything else is then loaded from the persistable state data.
attack_state.persistable = PersistableAttackState.apply_dict(attack_state.persistable, state_file_dict)
attack_state.restore_from_persistable_data()
if attack_state.persistable.broken_hill_version != script_version:
logger.warning(f"The state file '{attack_params.load_state_from_file}' was generated by Broken Hill version {attack_state.persistable.broken_hill_version}, but is being loaded using Broken Hill version {script_version}. If you need results that match as closely as possible, you should use Broken Hill version {attack_state.persistable.broken_hill_version} instead.")
# Finally, the AttackParams are replaced with the merged version generated earlier, by the first line outside this if block
#attack_state.persistable.attack_params = merged_attack_params
attack_state.persistable.attack_params = attack_params
attack_state.persistable.broken_hill_version = script_version
# Saving the options is done here instead of earlier so that they're the merged result of all possible option sources
if attack_state.persistable.attack_params.operating_mode == BrokenHillMode.SAVE_OPTIONS:
try:
options_data = attack_state.persistable.attack_params.to_dict()
json_options_data = json.dumps(options_data, indent=4)
safely_write_text_output_file(attack_state.persistable.attack_params.save_options_path, json_options_data)
logger.info(f"The current Broken Hill configuration has been written to '{attack_state.persistable.attack_params.save_options_path}'.")
except Exception as e:
logger.critical(f"Could not write the current Broken Hill configuration to '{attack_state.persistable.attack_params.save_options_path}': {e}\n{traceback.format_exc()}")
sys.exit(1)
sys.exit(0)
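# Optionally enable PyTorch's CUDA allocation-history recording. torch.cuda.memory._record_memory_history()
# is a private PyTorch API, so its exact behaviour may vary between PyTorch versions; the recorded history
# is presumably dumped later using the path stored in attack_params.torch_cuda_memory_history_file.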
if attack_state.persistable.attack_params.torch_cuda_memory_history_file is not None and attack_state.persistable.attack_params.using_cuda():
torch.cuda.memory._record_memory_history()
if not attack_params.load_state_from_file:
attack_state.initialize_devices()
attack_state.initialize_random_number_generators()
attack_state.persistable.initialize_language_manager()
attack_state.initialize_jailbreak_detector()
if attack_state.persistable.attack_params.save_state:
# if there is not an existing state file, create a new one in the specified directory
create_new_state_file = True
if attack_state.persistable.attack_params.state_file is not None:
if attack_state.persistable.attack_params.overwrite_output and attack_state.persistable.attack_params.overwrite_existing_state:
create_new_state_file = False
if create_new_state_file:
attack_state.persistable.attack_params.state_file = os.path.join(attack_state.persistable.attack_params.state_directory, get_broken_hill_state_file_name(attack_state))
# only test write capability if there is not an existing state file, so it's not overwritten
if not os.path.isfile(attack_state.persistable.attack_params.state_file):
verify_output_file_capability(attack_state.persistable.attack_params.state_file, attack_state.persistable.attack_params.overwrite_output)
logger.info(f"State information for this attack will be stored in '{attack_state.persistable.attack_params.state_file}'.")
start_dt = get_now()
start_ts = get_time_string(start_dt)
logger.info(f"Starting at {start_ts}")
# Parameter validation, warnings, and errors
if attack_state.persistable.attack_params.using_cpu():
using_ok_format = False
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.DEFAULT:
using_ok_format = True
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.FORCE_BFLOAT16:
using_ok_format = True
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.FORCE_FLOAT32:
using_ok_format = True
if using_ok_format:
logger.warning(f"You are using a CPU device for processing. Depending on your hardware, the default model data type may cause degraded performance. If you encounter poor performance, try specifying --model-data-type float32 when launching Broken Hill, as long as your device has sufficient system RAM. Consult the documentation for further details.")
else:
logger.warning(f"You are using a CPU device for processing, but Broken Hill is configured to use an unsupported PyTorch dtype when loading the model. In particular, if 'float16' is specified, it will also greatly increase runtimes. If you encounter unusual behaviour, such as iteration times of 10 hours, or incorrect output from the model, try specifying --model-data-type default. Consult the documentation for further details.")
non_cuda_devices = attack_state.persistable.attack_params.get_non_cuda_devices()
if len(non_cuda_devices) > 0:
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.FORCE_FLOAT16:
logger.warning(f"Using the following device(s) with the 'float16' model data format is not recommended: {non_cuda_devices}. Using this format on non-CUDA devices will cause Broken Hill to run extremely slowly. Expect performance about 100 times slower than using 'float16' on CUDA hardware, for example, and about 10 times slower than using 'float32' on CPU hardware.")
ietf_tag_names = None
ietf_tag_data = None
try:
ietf_tag_names, ietf_tag_data = attack_state.persistable.language_manager.get_ietf_tags()
except Exception as e:
logger.critical(f"Could not load the human language data bundled with Broken Hill: {e}\n{traceback.format_exc()}")
sys.exit(1)
if attack_state.persistable.attack_params.operating_mode == BrokenHillMode.LIST_IETF_TAGS:
ietf_message = "Supported IETF language tags in this version:"
for i in range(0, len(ietf_tag_names)):
ietf_message = f"{ietf_message}\n{ietf_tag_names[i]}\t{ietf_tag_data[ietf_tag_names[i]]}"
logger.info(ietf_message)
sys.exit(0)
attack_state.persistable.overall_result_data.start_date_time = start_ts
attack_state.persistable.performance_data.collect_torch_stats(attack_state, is_key_snapshot_event = True, location_description = "before loading model and tokenizer")
try:
attack_state.load_model()
if attack_state.model_type_name == "MosaicGPT":
logger.warning("This model is of type MosaicGPT. At the time this version of Broken Hill was made, MosaicGPT did not support the 'inputs_embeds' keyword when calling the forward method. If that is still the case when you are reading this message, Broken Hill will likely crash during the GCG attack.")
if attack_state.model_type_name == "GPTNeoXForCausalLM":
warn_about_bad_gpt_neox_weight_type = True
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.DEFAULT:
warn_about_bad_gpt_neox_weight_type = False
# still need to validate that AUTO will actually work
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.AUTO:
warn_about_bad_gpt_neox_weight_type = False
if attack_state.persistable.attack_params.model_data_format_handling == ModelDataFormatHandling.FORCE_FLOAT32:
warn_about_bad_gpt_neox_weight_type = False
logger.warning("This model is of type GPTNeoXForCausalLM. At the time this version of Broken Hill was made, GPT-NeoX did not perform correctly in PyTorch/Transformers when loaded with weights in float16 format, possibly any other dtype besides float32. If you encounter very long processing times or incorrect output (such as the LLM responding with only the <|endoftext|> token), try using one of the following options instead of your current --model-data-type selection:\n--model-data-type default\n--model-data-type auto\n--model-data-type float32")
# only build the token allowlist and denylist if this is a new run.
# If the state has been loaded, the lists are already populated.
# So is the tensor form of the list.
if not attack_params.load_state_from_file:
attack_state.build_token_allow_and_denylists()
original_model_size = 0
if attack_state.persistable.attack_params.display_model_size:
# Only perform this work if the results will actually be logged, to avoid unnecessary performance impact
# Assume the same comment for all instances of this pattern
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Determining model size.")
original_model_size = get_model_size(attack_state.model)
logger.info(f"Model size: {original_model_size}")
# This code still doesn't do anything useful, so don't get your hopes up!
attack_state.apply_model_quantization()
attack_state.apply_model_dtype_conversion()
# if attack_state.persistable.attack_params.conversion_dtype:
# #logger.debug(f"converting model dtype to {attack_state.persistable.attack_params.conversion_dtype}.")
# attack_state.model = attack_state.model.to(attack_state.persistable.attack_params.conversion_dtype)
if attack_state.persistable.attack_params.quantization_dtype or attack_state.persistable.attack_params.enable_static_quantization or attack_state.persistable.attack_params.conversion_dtype:
logger.warning(f"You've enabled quantization and/or type conversion, which are unlikely to work for the foreseeable future due to PyTorch limitations. Please see the comments in the source code for Broken Hill.")
if attack_state.persistable.attack_params.display_model_size:
quantized_model_size = get_model_size(attack_state.model)
size_factor = float(quantized_model_size) / float(original_model_size) * 100.0
size_factor_formatted = f"{size_factor:.2f}%"
logger.info(f"Model size after reduction: {quantized_model_size} ({size_factor_formatted} of original size)")
register_custom_conversation_templates(attack_state.persistable.attack_params)
attack_state.load_conversation_template()
attack_state.ignite_trash_fire_token_treasury()
# If the state was loaded from a file, the initial adversarial content information should already be populated.
if not attack_params.load_state_from_file:
attack_state.create_initial_adversarial_content()
attack_state.check_for_adversarial_content_token_problems()
logger.info(f"Initial adversarial content: {attack_state.persistable.initial_adversarial_content.get_full_description()}")
attack_state.persistable.current_adversarial_content = attack_state.persistable.initial_adversarial_content.copy()
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = "before creating adversarial content manager")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"creating adversarial content manager.")
attack_state.adversarial_content_manager = AdversarialContentManager(attack_state = attack_state,
conv_template = attack_state.conversation_template,
#adversarial_content = attack_state.persistable.initial_adversarial_content.copy(),
adversarial_content = attack_state.persistable.current_adversarial_content.copy(),
trash_fire_tokens = attack_state.trash_fire_token_treasury)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = "after creating adversarial content manager")
#import pdb; pdb.Pdb(nosigint=True).set_trace()
attack_state.persistable.original_new_adversarial_value_candidate_count = attack_state.persistable.attack_params.new_adversarial_value_candidate_count
attack_state.persistable.original_topk = attack_state.persistable.attack_params.topk
# Keep this out until things like denied tokens are file paths instead of inline, to keep from infecting result files with bad words
#attack_state.persistable.overall_result_data.attack_params = attack_state.persistable.attack_params
attack_state.test_conversation_template()
attack_state.perform_jailbreak_tests()
if attack_state.persistable.attack_params.operating_mode == BrokenHillMode.GCG_ATTACK_SELF_TEST:
logger.info(f"Broken Hill has completed all self-test operations and will now exit.")
end_ts = get_time_string(get_now())
attack_state.persistable.overall_result_data.end_date_time = end_ts
attack_state.write_output_files()
sys.exit(0)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = "before creating embedding_matrix")
# If loading the state from a file, reload the RNG states right before starting the loop, in case something during initialization has messed with them
if attack_params.load_state_from_file:
attack_state.restore_random_number_generator_states()
logger.info(f"Starting main loop")
while attack_state.persistable.main_loop_iteration_number < attack_state.persistable.attack_params.max_iterations:
display_iteration_number = attack_state.persistable.main_loop_iteration_number + 1
#attack_state.persistable.random_number_generator_states = attack_state.random_number_generators.get_current_states()
is_success = False
if user_aborted:
break
else:
try:
attack_results_current_iteration = AttackResultInfoCollection()
attack_results_current_iteration.iteration_number = attack_state.persistable.main_loop_iteration_number + 1
iteration_start_dt = get_now()
current_ts = get_time_string(iteration_start_dt)
current_elapsed_string = get_elapsed_time_string(start_dt, iteration_start_dt)
logger.info(f"{current_ts} - Main loop iteration {display_iteration_number} of {attack_state.persistable.attack_params.max_iterations} - elapsed time {current_elapsed_string} - successful attack count: {attack_state.persistable.successful_attack_count}")
attack_state.persistable.overall_result_data.end_date_time = current_ts
attack_state.persistable.overall_result_data.elapsed_time_string = current_elapsed_string
attack_state.persistable.performance_data.collect_torch_stats(attack_state, is_key_snapshot_event = True, location_description = f"beginning of main loop iteration {display_iteration_number}")
attack_data_previous_iteration = None
if attack_state.persistable.main_loop_iteration_number > 0:
#attack_data_previous_iteration = attack_data[len(attack_data) - 1]
attack_data_previous_iteration = attack_state.persistable.overall_result_data.attack_results[len(attack_state.persistable.overall_result_data.attack_results) - 1]
attack_state.persistable.tested_adversarial_content.append_if_new(attack_state.persistable.current_adversarial_content)
# TKTK: split the actual attack step out into a separate subclass of an attack class.
# Maybe TokenPermutationAttack => GreedyCoordinateGradientAttack?
# if this is not the first iteration, and the user has enabled emulation of the original attack, encode the current string, then use those IDs for this round instead of persisting everything in token ID format
if attack_state.persistable.main_loop_iteration_number > 0:
if attack_state.persistable.attack_params.reencode_adversarial_content_every_iteration:
reencoded_token_ids = attack_state.tokenizer.encode(attack_state.persistable.current_adversarial_content.as_string)
attack_state.persistable.current_adversarial_content = AdversarialContent.from_token_ids(attack_state, attack_state.trash_fire_token_treasury, reencoded_token_ids)
attack_state.adversarial_content_manager.adversarial_content = attack_state.persistable.current_adversarial_content
# Step 1. Encode user prompt (behavior + adv suffix) as tokens and return token ids.
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before creating input_id_data")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Calling get_input_ids with attack_state.persistable.current_adversarial_content = '{attack_state.persistable.current_adversarial_content.get_short_description()}'")
input_id_data = attack_state.adversarial_content_manager.get_prompt(adversarial_content = attack_state.persistable.current_adversarial_content, force_python_tokenizer = attack_state.persistable.attack_params.force_python_tokenizer)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after creating input_id_data")
# Only perform this work if the results will actually be logged
decoded_loss_slice = None
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
if attack_state.persistable.attack_params.generate_debug_logs_requiring_extra_tokenizer_calls:
decoded_input_tokens = get_decoded_tokens(attack_state, input_id_data.input_token_ids)
decoded_full_prompt_token_ids = get_decoded_tokens(attack_state, input_id_data.full_prompt_token_ids)
decoded_control_slice = get_decoded_tokens(attack_state, input_id_data.full_prompt_token_ids[input_id_data.slice_data.control])
decoded_target_slice = get_decoded_tokens(attack_state, input_id_data.full_prompt_token_ids[input_id_data.slice_data.target_output])
decoded_loss_slice = get_decoded_tokens(attack_state, input_id_data.full_prompt_token_ids[input_id_data.slice_data.loss])
logger.debug(f"decoded_input_tokens = '{decoded_input_tokens}'\n decoded_full_prompt_token_ids = '{decoded_full_prompt_token_ids}'\n decoded_control_slice = '{decoded_control_slice}'\n decoded_target_slice = '{decoded_target_slice}'\n decoded_loss_slice = '{decoded_loss_slice}'\n input_id_data.slice_data.control = '{input_id_data.slice_data.control}'\n input_id_data.slice_data.target_output = '{input_id_data.slice_data.target_output}'\n input_id_data.slice_data.loss = '{input_id_data.slice_data.loss}'\n input_id_data.input_token_ids = '{input_id_data.input_token_ids}'\n input_id_data.full_prompt_token_ids = '{input_id_data.full_prompt_token_ids}'")
decoded_loss_slice_string = get_escaped_string(attack_state.tokenizer.decode(input_id_data.full_prompt_token_ids[input_id_data.slice_data.loss]))
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before creating input_ids")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Converting input IDs to device")
input_ids = input_id_data.get_input_ids_as_tensor().to(attack_state.model_device)
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"input_ids after conversion = '{input_ids}'")
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after creating input_ids")
input_id_data_gcg_ops = input_id_data
input_ids_gcg_ops = input_ids
if attack_state.persistable.attack_params.ignore_prologue_during_gcg_operations:
conv_template_gcg_ops = attack_state.conversation_template.copy()
conv_template_gcg_ops.system_message=""
conv_template_gcg_ops.messages = []
adversarial_content_manager_gcg_ops = AdversarialContentManager(attack_state = attack_state,
conv_template = conv_template_gcg_ops,
adversarial_content = attack_state.persistable.current_adversarial_content.copy(),
trash_fire_tokens = attack_state.trash_fire_token_treasury)
input_id_data_gcg_ops = adversarial_content_manager_gcg_ops.get_prompt(adversarial_content = attack_state.persistable.current_adversarial_content, force_python_tokenizer = attack_state.persistable.attack_params.force_python_tokenizer)
input_ids_gcg_ops = input_id_data_gcg_ops.get_input_ids_as_tensor().to(attack_state.model_device)
best_new_adversarial_content = None
# preserve the RNG states because the code in this section is likely to reset them a bunch of times
rng_states = attack_state.random_number_generators.get_current_states()
# declare these here so they can be cleaned up later
coordinate_gradient = None
# during the first iteration, do not generate variations - test the value that was given
if attack_state.persistable.main_loop_iteration_number == 0:
logger.info(f"Testing initial adversarial value '{attack_state.persistable.current_adversarial_content.get_short_description()}'")
else:
# Step 2. Compute Coordinate Gradient
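# The coordinate gradient is the gradient of the adversarial loss with respect to a one-hot
# representation of each adversarial token position (see the GCG paper cited in
# get_script_description()); it estimates, for every position, how much substituting each
# vocabulary token would be expected to reduce the loss.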
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before creating coordinate gradient")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Computing coordinate gradient")
try:
coordinate_gradient = token_gradients(attack_state,
input_ids_gcg_ops,
input_id_data_gcg_ops)
except GradientCreationException as e:
raise GradientCreationException(f"Attempting to generate a coordinate gradient failed: {e}. Please contact a developer with steps to reproduce this issue if it has not already been reported.\n{traceback.format_exc()}")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"coordinate_gradient.shape[0] = {coordinate_gradient.shape[0]}")
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after creating coordinate gradient")
# if isinstance(random_generator_gradient, type(None)):
# random_generator_gradient = torch.Generator(device = coordinate_gradient.device).manual_seed(attack_state.persistable.attack_params.torch_manual_seed)
# Step 3. Sample a batch of new tokens based on the coordinate gradient.
# Notice that we only need the one that minimizes the loss.
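# In the llm-attacks/nanoGCG-derived sampling approach this code is based on (see the notes in
# get_script_description()), each candidate is built by replacing one adversarial token position
# with one of the top-k tokens suggested by the coordinate gradient, while excluding anything in
# not_allowed_tokens.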
with torch.no_grad():
got_candidate_list = False
new_adversarial_candidate_list = None
new_adversarial_candidate_list_filtered = None
losses = None
best_new_adversarial_content_id = None
best_new_adversarial_content = None
current_loss = None
current_loss_as_float = None
# BEGIN: wrap in loss threshold check
candidate_list_meets_loss_threshold = False
num_iterations_without_acceptable_loss = 0
# store the best value from each attempt in case no value is found that meets the threshold
best_failed_attempts = AdversarialContentList()
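# The loop below keeps regenerating candidate batches until one contains a candidate whose loss is
# no greater than the previous iteration's loss plus required_loss_threshold, or until
# loss_threshold_max_attempts is reached; in the latter case the best candidate seen so far is used
# instead, unless exit_on_loss_threshold_failure is set, in which case Broken Hill exits.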
while not candidate_list_meets_loss_threshold:
got_candidate_list = False
while not got_candidate_list:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Generating new adversarial content candidates")
new_adversarial_candidate_list = None
try:
new_adversarial_candidate_list = get_adversarial_content_candidates(attack_state,
coordinate_gradient,
not_allowed_tokens = attack_state.get_token_denylist_as_cpu_tensor())
except GradientSamplingException as e:
raise GradientSamplingException(f"Attempting to generate a new set of candidate adversarial data failed: {e}. Please contact a developer with steps to reproduce this issue if it has not already been reported.\n{traceback.format_exc()}")
except RuntimeError as e:
raise GradientSamplingException(f"Attempting to generate a new set of candidate adversarial data failed with a low-level error: {e}. This is typically caused by excessive or conflicting candidate-filtering options. For example, the operator may have specified a regular expression filter that rejects long strings, but also specified a long initial adversarial value. This error is unrecoverable. If you believe the error was not due to excessive/conflicting filtering options, please submit an issue.\n{traceback.format_exc()}")
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before getting filtered candidates")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_adversarial_candidate_list: {new_adversarial_candidate_list.adversarial_content}")
# Note: I'm leaving this explanation here for historical reference
# Step 3.3 This step ensures all adversarial candidates have the same number of tokens.
# This step is necessary because tokenizers are not invertible so Encode(Decode(tokens)) may produce a different tokenization.
# We ensure the number of token remains [constant -Ben] to prevent the memory keeps growing and run into OOM.
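# Illustrative (and tokenizer-dependent) example of the problem described above: a candidate whose
# token IDs decode to the two pieces "un" + "install" may re-encode as the single token "uninstall",
# so Encode(Decode(tokens)) returns fewer tokens than the candidate originally contained.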
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Getting filtered adversarial content candidates")
new_adversarial_candidate_list_filtered = get_filtered_cands(attack_state, new_adversarial_candidate_list, filter_cand = True)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting filtered candidates")
if len(new_adversarial_candidate_list_filtered.adversarial_content) > 0:
got_candidate_list = True
else:
# try to find a way to increase the number of options available
something_has_changed = False
standard_explanation_intro = "The attack has failed to generate any adversarial values at this iteration that meet the specified filtering criteria and have not already been tested."
standard_explanation_outro = "You can try specifying larger values for --max-batch-size-new-adversarial-tokens and/or --max-topk to avoid this error, or enabling --add-token-when-no-candidates-returned and/or --delete-token-when-no-candidates-returned if they are not already enabled."
if attack_state.persistable.attack_params.add_token_when_no_candidates_returned:
token_count_limited = True
if isinstance(attack_state.persistable.attack_params.candidate_filter_tokens_max, type(None)):
token_count_limited = False
if token_count_limited:
if len(attack_state.persistable.current_adversarial_content.token_ids) < attack_state.persistable.attack_params.candidate_filter_tokens_max:
token_count_limited = False
current_short_description = attack_state.persistable.current_adversarial_content.get_short_description()
if token_count_limited:
logger.warning(f"{standard_explanation_intro} The option to add an additional token is enabled, but the current adversarial content {current_short_description} is already at the limit of {attack_state.persistable.attack_params.candidate_filter_tokens_max} tokens.")
else:
attack_state.persistable.current_adversarial_content.duplicate_random_token(numpy_random_generator, attack_state.tokenizer)
new_short_description = attack_state.persistable.current_adversarial_content.get_short_description()
something_has_changed = True
logger.info(f"{standard_explanation_intro} Because the option to add an additional token is enabled, the current adversarial content has been modified from {current_short_description} to {new_short_description}.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"The option to add an additional token is disabled.")
if not something_has_changed:
if attack_state.persistable.attack_params.delete_token_when_no_candidates_returned:
token_count_limited = True
minimum_token_count = 1
if isinstance(attack_state.persistable.attack_params.candidate_filter_tokens_min, type(None)):
token_count_limited = False
else:
if attack_state.persistable.attack_params.candidate_filter_tokens_min > 1:
minimum_token_count = attack_state.persistable.attack_params.candidate_filter_tokens_min
if len(attack_state.persistable.current_adversarial_content.token_ids) > attack_state.persistable.attack_params.candidate_filter_tokens_min:
token_count_limited = False
if not token_count_limited:
if len(attack_state.persistable.current_adversarial_content.token_ids) < 2:
token_count_limited = True
current_short_description = attack_state.persistable.current_adversarial_content.get_short_description()
if token_count_limited:
logger.warning(f"{standard_explanation_intro} The option to delete a random token is enabled, but the current adversarial content {current_short_description} is already at the minimum of {minimum_token_count} token(s).")
else:
attack_state.persistable.current_adversarial_content.delete_random_token(numpy_random_generator, attack_state.tokenizer)
new_short_description = attack_state.persistable.current_adversarial_content.get_short_description()
something_has_changed = True
logger.info(f"{standard_explanation_intro} Because the option to delete a random token is enabled, the current adversarial content has been modified from {current_short_description} to {new_short_description}.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"The option to delete a random token is disabled.")
if not something_has_changed:
new_new_adversarial_value_candidate_count = attack_state.persistable.attack_params.new_adversarial_value_candidate_count + attack_state.persistable.original_new_adversarial_value_candidate_count
increase_new_adversarial_value_candidate_count = True
if not isinstance(attack_state.persistable.attack_params.max_new_adversarial_value_candidate_count, type(None)):
if new_new_adversarial_value_candidate_count > attack_state.persistable.attack_params.max_new_adversarial_value_candidate_count:
new_new_adversarial_value_candidate_count = attack_state.persistable.attack_params.max_new_adversarial_value_candidate_count
if new_new_adversarial_value_candidate_count <= attack_state.persistable.attack_params.new_adversarial_value_candidate_count:
increase_new_adversarial_value_candidate_count = False
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_new_adversarial_value_candidate_count > attack_state.persistable.attack_params.new_adversarial_value_candidate_count.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_new_adversarial_value_candidate_count <= attack_state.persistable.attack_params.max_new_adversarial_value_candidate_count.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"attack_state.persistable.attack_params.max_new_adversarial_value_candidate_count is None.")
if increase_new_adversarial_value_candidate_count:
logger.warning(f"{standard_explanation_intro} This may be due to excessive post-generation filtering options. The --batch-size-new-adversarial-tokens value is being increased from {attack_state.persistable.attack_params.new_adversarial_value_candidate_count} to {new_new_adversarial_value_candidate_count} to increase the number of candidate values. {standard_explanation_outro}")
attack_state.persistable.attack_params.new_adversarial_value_candidate_count = new_new_adversarial_value_candidate_count
something_has_changed = True
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Not increasing the --batch-size-new-adversarial-tokens value.")
if not something_has_changed:
new_topk = attack_state.persistable.attack_params.topk + attack_state.persistable.original_topk
increase_topk = True
if not isinstance(attack_state.persistable.attack_params.max_topk, type(None)):
if new_topk > attack_state.persistable.attack_params.max_topk:
new_topk = attack_state.persistable.attack_params.max_topk
if new_topk <= attack_state.persistable.attack_params.topk:
increase_topk = False
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_topk > attack_state.persistable.attack_params.topk.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_topk <= attack_state.persistable.attack_params.max_topk.")
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"attack_state.persistable.attack_params.max_topk is None.")
if increase_topk:
logger.warning(f"{standard_explanation_intro} This may be due to excessive post-generation filtering options. The --topk value is being increased from {attack_state.persistable.attack_params.topk} to {new_topk} to increase the number of candidate values. {standard_explanation_outro}")
attack_state.persistable.attack_params.topk = new_topk
something_has_changed = True
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Not increasing the --topk value.")
if not something_has_changed:
raise GradientSamplingException(f"{standard_explanation_intro} This may be due to excessive post-generation filtering options. Because the 'topk' value has already reached or exceeded the specified maximum ({attack_state.persistable.attack_params.max_topk}), and no other options for increasing the number of potential candidates is possible in the current configuration, Broken Hill will now exit. {standard_explanation_outro}\n{traceback.format_exc()}")
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting finalized filtered candidates")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"new_adversarial_candidate_list_filtered: '{new_adversarial_candidate_list_filtered.to_dict()}'")
# Step 3.4 Compute loss on these candidates and take the argmin.
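# target_loss() scores each candidate by how well the model predicts the target output slice of the
# prompt; depending on the configured LossAlgorithm this is a cross-entropy or mellowmax-style loss
# (the mellowmax implementation is derived from nanoGCG, per the description at the top of this file).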
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Getting logits")
logits, ids = get_logits(attack_state,
input_ids = input_ids_gcg_ops,
adversarial_content = attack_state.persistable.current_adversarial_content,
adversarial_candidate_list = new_adversarial_candidate_list_filtered,
return_ids = True)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting logits")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Calculating target loss")
losses = target_loss(attack_state, logits, ids, input_id_data_gcg_ops)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting loss values")
# get rid of logits and ids immediately to save device memory, as it's no longer needed after the previous operation
# This frees about 1 GiB of device memory for a 500M model on CPU or CUDA
del logits
del ids
gc.collect()
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after deleting logits and ids and running gc.collect")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"losses = {losses}")
logger.debug(f"Getting losses argmin")
best_new_adversarial_content_id = losses.argmin()
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting best new adversarial content ID")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"best_new_adversarial_content_id = {best_new_adversarial_content_id}")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Setting best new adversarial content")
best_new_adversarial_content = new_adversarial_candidate_list_filtered.adversarial_content[best_new_adversarial_content_id].copy()
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after getting best new adversarial content")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Getting current loss")
current_loss = losses[best_new_adversarial_content_id]
del losses
gc.collect()
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after deleting losses and running gc.collect")
current_loss_as_float = None
try:
current_loss_as_float = float(f"{current_loss.detach().to(torch.float32).cpu().numpy()}")
except Exception as e:
logger.error(f"Could not convert the current loss value '{current_loss}' to a floating-point number: {e}\n{traceback.format_exc()}\nThe value 100.0 will be used instead.")
current_loss_as_float = 100.0
best_new_adversarial_content.original_loss = current_loss_as_float
if isinstance(attack_state.persistable.attack_params.required_loss_threshold, type(None)) or attack_state.persistable.main_loop_iteration_number == 0:
candidate_list_meets_loss_threshold = True
else:
if attack_data_previous_iteration is None:
candidate_list_meets_loss_threshold = True
else:
if isinstance(attack_data_previous_iteration.loss, type(None)):
candidate_list_meets_loss_threshold = True
if not candidate_list_meets_loss_threshold:
if best_new_adversarial_content.original_loss <= (attack_data_previous_iteration.loss + attack_state.persistable.attack_params.required_loss_threshold):
candidate_list_meets_loss_threshold = True
if not candidate_list_meets_loss_threshold:
num_iterations_without_acceptable_loss += 1
best_failed_attempts.append_if_new(best_new_adversarial_content)
loss_attempt_stats_message = f"{num_iterations_without_acceptable_loss} unsuccessful attempt(s) to generate a list of random candidates that has at least one candidate with a loss lower than {(attack_data_previous_iteration.loss + attack_state.persistable.attack_params.required_loss_threshold)}"
if not isinstance(attack_state.persistable.attack_params.required_loss_threshold, type(None)):
if attack_state.persistable.attack_params.required_loss_threshold != 0.0:
loss_attempt_stats_message += f" (previous loss of {attack_data_previous_iteration.loss} plus the specified threshold value {attack_state.persistable.attack_params.required_loss_threshold})"
loss_attempt_stats_message += f". Best value during this attempt was {current_loss_as_float}."
logger.warning(loss_attempt_stats_message)
if not isinstance(attack_state.persistable.attack_params.loss_threshold_max_attempts, type(None)):
if num_iterations_without_acceptable_loss >= attack_state.persistable.attack_params.loss_threshold_max_attempts:
loss_attempt_result_message = f"{num_iterations_without_acceptable_loss} unsuccessful attempt(s) has reached the limit of {attack_state.persistable.attack_params.loss_threshold_max_attempts} attempts."
if attack_state.persistable.attack_params.exit_on_loss_threshold_failure:
loss_attempt_result_message += " Broken Hill has been configured to exit when this condition occurs."
raise LossThresholdException(loss_attempt_result_message)
else:
best_new_adversarial_content = best_failed_attempts.get_content_with_lowest_loss()
candidate_list_meets_loss_threshold = True
loss_attempt_result_message += f" Broken Hill has been configured to use the adversarial content with the lowest loss discovered during this iteration when this condition occurs. Out of {len(best_failed_attempts.adversarial_content)} unique set(s) of tokens discovered during this iteration, the lowest loss value was {best_new_adversarial_content.original_loss} versus the previous loss of {attack_data_previous_iteration.loss}"
if attack_state.persistable.attack_params.required_loss_threshold is not None and attack_state.persistable.attack_params.required_loss_threshold != 0.0:
loss_attempt_result_message += f" and threshold {attack_state.persistable.attack_params.required_loss_threshold}"
loss_attempt_result_message += ". The adversarial content with that loss value will be used."
logger.warning(loss_attempt_result_message)
# END: wrap in loss threshold check
# Update the running attack_state.persistable.current_adversarial_content with the best candidate
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before updating adversarial value")
logger.info(f"Updating adversarial value to the best value out of the new permutation list and testing it.\nWas: {attack_state.persistable.current_adversarial_content.get_short_description()} ({len(attack_state.persistable.current_adversarial_content.token_ids)} tokens)\nNow: {best_new_adversarial_content.get_short_description()} ({len(best_new_adversarial_content.token_ids)} tokens)")
attack_state.persistable.current_adversarial_content = best_new_adversarial_content
logger.info(f"Loss value for the new adversarial value in relation to '{decoded_loss_slice_string}'\nWas: {attack_data_previous_iteration.loss}\nNow: {attack_state.persistable.current_adversarial_content.original_loss}")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
if attack_state.persistable.attack_params.generate_debug_logs_requiring_extra_tokenizer_calls:
# this is in this block because the variable will be None if the previous check of the same type was bypassed
logger.debug(f"decoded_loss_slice = '{decoded_loss_slice}'")
logger.debug(f"input_id_data.full_prompt_token_ids[input_id_data.slice_data.loss] = '{input_id_data.full_prompt_token_ids[input_id_data.slice_data.loss]}'")
logger.debug(f"input_id_data_gcg_ops.full_prompt_token_ids[input_id_data_gcg_ops.slice_data.loss] = '{input_id_data_gcg_ops.full_prompt_token_ids[input_id_data_gcg_ops.slice_data.loss]}'")
#attack_results_current_iteration.loss = current_loss_as_float
attack_results_current_iteration.loss = attack_state.persistable.current_adversarial_content.original_loss
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before creating best_new_adversarial_content_input_token_id_data")
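# Rebuild the full prompt token IDs using the newly selected adversarial content, so that the jailbreak tests below operate on the same prompt the model will actually see.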
best_new_adversarial_content_input_token_id_data = attack_state.adversarial_content_manager.get_prompt(adversarial_content = attack_state.persistable.current_adversarial_content, force_python_tokenizer = attack_state.persistable.attack_params.force_python_tokenizer)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after creating best_new_adversarial_content_input_token_id_data")
# Preserve the RNG states, because the code in this section is likely to reset them several times.
# The states are captured in more than one place because this block does not always execute;
# any alteration to them, however, will happen in the next section.
rng_states = attack_state.random_number_generators.get_current_states()
attack_results_current_iteration.adversarial_content = attack_state.persistable.current_adversarial_content.copy()
# BEGIN: do for every random seed
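# Test run 0 is the "canonical" result: default seeds and do_sample = False.
# Each additional run (up to random_seed_comparisons) uses a distinct random seed and a temperature interpolated between model_temperature_range_begin and model_temperature_range_end, to gauge how robust the candidate is to sampling variation.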
prng_seed_index = -1
for randomized_test_number in range(0, attack_state.persistable.attack_params.random_seed_comparisons + 1):
prng_seed_index += 1
attack_data_current_iteration = AttackResultInfo()
attack_data_current_iteration.numpy_random_seed = attack_state.persistable.attack_params.numpy_random_seed
attack_data_current_iteration.torch_manual_seed = attack_state.persistable.attack_params.torch_manual_seed
attack_data_current_iteration.torch_cuda_manual_seed_all = attack_state.persistable.attack_params.torch_cuda_manual_seed_all
current_temperature = attack_state.persistable.attack_params.model_temperature_range_begin
# For the first run, leave the model in its default do_sample configuration
do_sample = False
if randomized_test_number == 0:
attack_data_current_iteration.is_canonical_result = True
attack_results_current_iteration.set_values(attack_state, best_new_adversarial_content_input_token_id_data.full_prompt_token_ids, best_new_adversarial_content_input_token_id_data.get_user_input_token_ids())
else:
if randomized_test_number == attack_state.persistable.attack_params.random_seed_comparisons or attack_state.persistable.attack_params.model_temperature_range_begin == attack_state.persistable.attack_params.model_temperature_range_end:
current_temperature = attack_state.persistable.attack_params.model_temperature_range_end
else:
current_temperature = attack_state.persistable.attack_params.model_temperature_range_begin + (((attack_state.persistable.attack_params.model_temperature_range_end - attack_state.persistable.attack_params.model_temperature_range_begin) / float(attack_state.persistable.attack_params.random_seed_comparisons) * randomized_test_number))
# For all other runs, enable do_sample to randomize results
do_sample = True
# Pick the next random seed that's not equivalent to any of the initial values
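# attack_state.random_seed_values is assumed to be a pre-populated list of candidate seeds; seeds that collide with any of the configured initial seeds are skipped so the comparison runs are genuinely different.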
got_random_seed = False
while not got_random_seed:
random_seed = attack_state.random_seed_values[prng_seed_index]
seed_already_used = False
if random_seed == attack_state.persistable.attack_params.numpy_random_seed:
seed_already_used = True
if random_seed == attack_state.persistable.attack_params.torch_manual_seed:
seed_already_used = True
if random_seed == attack_state.persistable.attack_params.torch_cuda_manual_seed_all:
seed_already_used = True
if seed_already_used:
prng_seed_index += 1
len_random_seed_values = len(attack_state.random_seed_values)
if prng_seed_index >= len_random_seed_values:
raise MyCurrentMentalImageOfALargeValueShouldBeEnoughForAnyoneException(f"Exceeded the number of random seeds available ({len_random_seed_values}).")
else:
got_random_seed = True
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Temporarily setting all random seeds to {random_seed} to compare results")
numpy.random.seed(random_seed)
torch.manual_seed(random_seed)
if attack_state.persistable.attack_params.using_cuda():
torch.cuda.manual_seed_all(random_seed)
attack_data_current_iteration.numpy_random_seed = random_seed
attack_data_current_iteration.torch_manual_seed = random_seed
attack_data_current_iteration.torch_cuda_manual_seed_all = random_seed
attack_data_current_iteration.temperature = current_temperature
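# Generate output using the candidate adversarial content at the selected temperature and check it against the configured jailbreak detection rules.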
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - before checking for jailbreak success")
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Checking for successful jailbreak")
is_success, jailbreak_check_data, jailbreak_check_generation_results = attack_state.check_for_attack_success(best_new_adversarial_content_input_token_id_data,
current_temperature,
do_sample = do_sample)
attack_state.persistable.performance_data.collect_torch_stats(attack_state, location_description = f"main loop iteration {display_iteration_number} - after checking for jailbreak success")
if is_success:
if attack_data_current_iteration.is_canonical_result:
attack_data_current_iteration.canonical_llm_jailbroken = True
attack_data_current_iteration.jailbreak_detected = True
attack_results_current_iteration.jailbreak_detection_count += 1
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Passed:{is_success}\nCurrent best new adversarial content: '{attack_state.persistable.current_adversarial_content.get_short_description()}'")
full_output_dataset_name = "full_output"
jailbreak_check_dataset_name = "jailbreak_check"
if attack_state.persistable.attack_params.display_full_failed_output:
jailbreak_check_dataset_name = full_output_dataset_name
attack_data_current_iteration.result_data_sets[jailbreak_check_dataset_name] = jailbreak_check_data
# only generate full output if it hasn't already just been generated
if not attack_state.persistable.attack_params.display_full_failed_output and is_success:
full_output_data = AttackResultInfoData()
# Note: for randomized variations (do_sample is True), reset the random seeds so that the full output begins identically to the shorter output generated above
if do_sample:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Temporarily setting all random seeds to {random_seed} to generate full output")
numpy.random.seed(random_seed)
torch.manual_seed(random_seed)
if attack_state.persistable.attack_params.using_cuda():
torch.cuda.manual_seed_all(random_seed)
generation_results = attack_state.generate(best_new_adversarial_content_input_token_id_data, current_temperature, do_sample = do_sample, generate_full_output = True)
full_output_data.set_values(attack_state, generation_results.max_new_tokens, generation_results.output_token_ids, generation_results.output_token_ids_output_only)
attack_data_current_iteration.result_data_sets[full_output_dataset_name] = full_output_data
attack_results_current_iteration.results.append(attack_data_current_iteration)
# END: do for every random seed
# restore the RNG states
attack_state.random_number_generators.set_states(rng_states)
attack_results_current_iteration.update_unique_output_values()
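# Build a per-iteration status summary for the operator: the current user input string, how many test runs produced a jailbreak, whether the canonical (non-sampled) run was jailbroken, and each unique output observed along with its occurrence count.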
iteration_status_message = f"Status:\n"
iteration_status_message = f"{iteration_status_message}Current input string:\n---\n{attack_results_current_iteration.decoded_user_input_string}\n---\n"
iteration_status_message = f"{iteration_status_message}Successful jailbreak attempts detected: {attack_results_current_iteration.jailbreak_detection_count}."
if attack_results_current_iteration.canonical_llm_jailbroken:
iteration_status_message = f"{iteration_status_message} Canonical LLM instance was jailbroken."
else:
iteration_status_message= f"{iteration_status_message} Canonical LLM instance was not jailbroken."
iteration_status_message = f"{iteration_status_message}\n{attack_results_current_iteration.unique_result_count} unique output(s) generated during testing:\n"
for uov_string in attack_results_current_iteration.unique_results.keys():
uov_count = attack_results_current_iteration.unique_results[uov_string]
iteration_status_message = f"{iteration_status_message}--- {uov_count} occurrence(s): ---\n"
iteration_status_message = f"{iteration_status_message}{uov_string}\n"
iteration_status_message = f"{iteration_status_message}---\n"
iteration_status_message = f"{iteration_status_message}Current best new adversarial content: {attack_state.persistable.current_adversarial_content.get_short_description()}"
logger.info(iteration_status_message)
# TKTK: maybe make this a threshold
if attack_results_current_iteration.jailbreak_detection_count > 0:
attack_state.persistable.successful_attack_count += 1
iteration_end_dt = get_now()
iteration_elapsed = iteration_end_dt - iteration_start_dt
attack_results_current_iteration.total_processing_time_seconds = iteration_elapsed.total_seconds()
attack_state.persistable.overall_result_data.attack_results.append(attack_results_current_iteration)
if attack_state.persistable.attack_params.write_output_every_iteration:
attack_state.write_output_files()
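# Rollback handling: when enabled, compare this iteration's loss and/or jailbreak count against the best values seen so far (plus any configured tolerance).
# If either check fails, the current adversarial content is recorded as tested and discarded, and the last-known-good content is restored for the next iteration.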
rollback_triggered = False
if attack_state.persistable.main_loop_iteration_number > 0:
rollback_message = ""
if attack_state.persistable.attack_params.rollback_on_loss_increase:
if (attack_results_current_iteration.loss - attack_state.persistable.attack_params.rollback_on_loss_threshold) > attack_state.persistable.best_loss_value:
if attack_state.persistable.attack_params.rollback_on_loss_threshold == 0.0:
rollback_message += f"The loss value for the current iteration ({attack_results_current_iteration.loss}) is greater than the best value achieved during this run ({attack_state.persistable.best_loss_value}). "
else:
rollback_message += f"The loss value for the current iteration ({attack_results_current_iteration.loss}) is greater than the allowed delta of {attack_state.persistable.attack_params.rollback_on_loss_threshold} from the best value achieved during this run ({attack_state.persistable.best_loss_value}). "
rollback_triggered = True
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Rollback not triggered by current loss value {attack_results_current_iteration.loss} versus current best value {attack_state.persistable.best_loss_value} and threshold {attack_state.persistable.attack_params.rollback_on_loss_threshold}.")
if attack_state.persistable.attack_params.rollback_on_jailbreak_count_decrease:
if (attack_results_current_iteration.jailbreak_detection_count + attack_state.persistable.attack_params.rollback_on_jailbreak_count_threshold) < attack_state.persistable.best_jailbreak_count:
if attack_state.persistable.attack_params.rollback_on_jailbreak_count_threshold == 0:
rollback_message += f"The jailbreak detection count for the current iteration ({attack_results_current_iteration.jailbreak_detection_count}) is less than for the best count achieved during this run ({attack_state.persistable.best_jailbreak_count}). "
else:
rollback_message += f"The jailbreak detection count for the current iteration ({attack_results_current_iteration.jailbreak_detection_count}) is less than the allowed delta of {attack_state.persistable.attack_params.rollback_on_jailbreak_count_threshold} from the best count achieved during this run ({attack_state.persistable.best_jailbreak_count}). "
rollback_triggered = True
else:
if attack_state.log_manager.get_lowest_log_level() <= logging.DEBUG:
logger.debug(f"Rollback not triggered by current jailbreak count {attack_results_current_iteration.jailbreak_detection_count} versus current best value {attack_state.persistable.best_jailbreak_count} and threshold {attack_state.persistable.attack_params.rollback_on_jailbreak_count_threshold}.")
# TKTK: if use of a threshold has allowed a score to drop below the last best value for x iterations, roll all the way back to the adversarial value that resulted in the current best value
# maybe use a tree model, with each branch from a node allowed to decrease 50% the amount of the previous branch, and too many failures to reach the value of the previous branch triggers a rollback to that branch
# That would allow some random exploration of various branches, at least allowing for the possibility of discovering a strong value within them, but never getting stuck for too long
if rollback_triggered:
#rollback_message += f"Rolling back to the last-known-good adversarial data {attack_state.persistable.last_known_good_adversarial_content.get_short_description()} for the next iteration instead of using this iteration's result {attack_state.persistable.current_adversarial_content.get_short_description()}."
rollback_message += f"Rolling back to the last-known-good adversarial data for the next iteration instead of using this iteration's result.\nThis iteration: '{attack_state.persistable.current_adversarial_content.get_short_description()}'\nLast-known-good: {attack_state.persistable.last_known_good_adversarial_content.get_short_description()}."
logger.info(rollback_message)
# add the rejected result to the list of tested results to avoid getting stuck in a loop
attack_state.persistable.tested_adversarial_content.append_if_new(attack_state.persistable.current_adversarial_content)
# roll back
#adversarial_content = attack_state.persistable.last_known_good_adversarial_content.copy()
#attack_state.persistable.current_adversarial_content = adversarial_content
attack_state.persistable.current_adversarial_content = attack_state.persistable.last_known_good_adversarial_content.copy()
# only update the "last-known-good" results if no rollback was triggered (for any reason)
# otherwise, if someone has multiple rollback options enabled, and only one of them is tripped, the other path will end up containing bad data
if not rollback_triggered:
rollback_notification_message = f"Updating last-known-good adversarial value from {attack_state.persistable.last_known_good_adversarial_content.get_short_description()} to {attack_state.persistable.current_adversarial_content.get_short_description()}."