diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 239c203162..34072d0f17 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -51,20 +51,21 @@ def from_pretrained(self, path_to_model): model = ExLlamaV2(config) - if shared.args.cache_8bit: - cache = ExLlamaV2Cache_8bit(model, lazy=True) - else: - cache = ExLlamaV2Cache(model, lazy=True) - - if shared.args.autosplit: - model.load_autosplit(cache) - else: + if not shared.args.autosplit: split = None if shared.args.gpu_split: split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] model.load(split) + if shared.args.cache_8bit: + cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit) + else: + cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit) + + if shared.args.autosplit: + model.load_autosplit(cache) + tokenizer = ExLlamaV2Tokenizer(config) generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index e5b35a44f3..1e21c2f1e0 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel): def __init__(self, config: ExLlamaV2Config): super().__init__(PretrainedConfig()) self.ex_config = config - self.ex_model = ExLlamaV2(config) self.loras = None self.generation_config = GenerationConfig() - if shared.args.cache_8bit: - self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True) - else: - self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True) + self.ex_model = ExLlamaV2(config) - if shared.args.autosplit: - self.ex_model.load_autosplit(self.ex_cache) - else: + if not shared.args.autosplit: split = None if shared.args.gpu_split: split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] self.ex_model.load(split) + if shared.args.cache_8bit: + self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit) + else: + self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit) + + if shared.args.autosplit: + self.ex_model.load_autosplit(self.ex_cache) + self.past_seq = None if shared.args.cfg_cache: if shared.args.cache_8bit: