001 (ns proomp.util.pipe-utils
002 (:require [proomp.config :as config]
003 [proomp.domain.prompt.prompt :as prom]
004 [proomp.domain.image.resolution :as res]
005 [proomp.domain.pipe.pipe-setup :as pipe-setup]
006 [cambium.core :as log]
007 [libpython-clj2.require :refer [require-python]]
008 [libpython-clj2.python :refer [py. py.-] :as py])
009 (:import (proomp.domain.image.resolution Resolution)
010 (proomp.domain.prompt.prompt Prompt)))
011
;; Load the Python-side dependencies at namespace-load time.
;; Order matters: torch must be importable before diffusers pulls it in.
(require-python 'torch '[torch.cuda :as cuda])
(require-python 'transformers)
;; Only the three pipeline classes used below are referred in.
(require-python '[diffusers :refer [StableDiffusionPipeline
                                    StableDiffusionImg2ImgPipeline
                                    StableDiffusionUpscalePipeline]])
017
018 ;(require-python '[riffusion :refer [RiffusionPipeline]])
019
;; Torch device string every pipeline is moved to (see send-to-device).
(defonce device "cuda")
;; When true, ->pipeline calls enable_attention_slicing on each new pipe.
(defonce enable-attention-slicing? true)
;; When true, ->pipeline enables xformers memory-efficient attention on each new pipe.
(defonce use-memory-efficient-attention? true)
023
(defn- clear-cuda-cache
  "Release unused cached GPU memory via torch.cuda.empty_cache, then trace-log."
  []
  (cuda/empty_cache)
  (log/trace "Cuda cache cleared."))
027
028 (defn- send-to-device [pipe] (py. pipe "to" device))
(defn- enable-attention-slicing
  "Call `enable_attention_slicing` on `pipe` and debug-log the fact."
  [pipe]
  (py. pipe "enable_attention_slicing")
  (log/debug "Attention slicing enabled."))
032
(defn- enable-memory-efficient-attention
  "Turn on xformers memory-efficient attention for `pipe` and debug-log the fact."
  [pipe]
  (py. pipe "set_use_memory_efficient_attention_xformers" true)
  (log/debug "Memory efficient attention enabled."))
036
(defn- ->pipeline
  "Instantiate a diffusers pipeline of `pipeline-class` from the model at
  `model-path` (float16 weights), move it to `device`, and apply the
  optional memory optimisations controlled by the defonce flags above.
  Clears the CUDA cache first to free room for the new weights.
  Returns the configured pipe."
  ;; Renamed param from `type` to avoid shadowing clojure.core/type.
  [pipeline-class model-path]
  (log/debug "Creating Pipeline.")
  (clear-cuda-cache)
  (let [pipe (py. pipeline-class "from_pretrained" model-path
                  :torch_dtype torch/float16)]
    (send-to-device pipe)
    ;; `when` (not one-armed `if`) is the idiomatic form for side effects.
    (when enable-attention-slicing?
      (enable-attention-slicing pipe))
    (when use-memory-efficient-attention?
      (enable-memory-efficient-attention pipe))
    (log/trace {:pipe pipe})
    pipe))
049
050 (defn ->text-to-image-pipeline [] (->pipeline StableDiffusionPipeline config/model-path))
051 (defn ->image-to-image-pipeline [] (->pipeline StableDiffusionImg2ImgPipeline config/model-path))
052 (defn ->upscaler-pipeline [] (->pipeline StableDiffusionUpscalePipeline config/upscaler-model-path))
053 ;(defn ->riffusion-pipeline [] (->pipeline RiffusionPipeline config/riffusion-model-path))
054
055 (defn- ->generator [seed] (py. (py/$c torch/Generator device) "manual_seed" seed))
056
057 (defn- extract-first-image [result] (nth (py.- result :images) 0))
058
(defn generate-image
  "Run the text→image pipeline for `prompt` at the active resolution,
  seeded by `seed`; steps/scale come from pipe-setup/image-pipe-setup.
  Returns the first generated image."
  [pipe ^Prompt prompt seed]
  (let [^Resolution resolution res/active-image-resolution
        result (py/$c pipe (prom/full-prompt prompt)
                      :negative_prompt (prom/full-negative-prompt prompt)
                      :generator (->generator seed)
                      :height (:h resolution) :width (:w resolution)
                      :num_inference_steps (:iterations pipe-setup/image-pipe-setup)
                      :guidance_scale (:scale pipe-setup/image-pipe-setup))]
    (extract-first-image result)))
068
(defn generate-upscale
  "Upscale `image` with `up-pipe`, conditioned on `prompt`.
  Returns the first upscaled image.
  The previously hard-coded sampler settings are now optional keyword
  overrides (backward compatible — the 3-arg call behaves exactly as before):
    :num-inference-steps  denoising steps            (default 75)
    :guidance-scale       classifier-free guidance   (default 9.0)
    :noise-level          noise added to the input   (default 20)"
  [up-pipe ^Prompt prompt image
   & {:keys [num-inference-steps guidance-scale noise-level]
      :or   {num-inference-steps 75 guidance-scale 9.0 noise-level 20}}]
  (extract-first-image
    (py/$c up-pipe (prom/full-prompt prompt)
           :negative_prompt (prom/full-negative-prompt prompt)
           :image image
           :num_inference_steps num-inference-steps
           :guidance_scale guidance-scale
           :noise_level noise-level)))
077
(defn generate-i2i
  "Image→image generation: transform `init-image` toward `prompt`,
  seeded by `seed`; strength/steps/scale come from pipe-setup/i2i-pipe-setup.
  Returns the first generated image."
  [pipe ^Prompt prompt seed init-image]
  (let [setup pipe-setup/i2i-pipe-setup
        result (py/$c pipe (prom/full-prompt prompt)
                      :negative_prompt (prom/full-negative-prompt prompt)
                      ;; NOTE(review): newer diffusers renamed init_image -> image;
                      ;; verify against the pinned diffusers version before upgrading.
                      :init_image init-image
                      :strength (:noise setup)
                      :generator (->generator seed)
                      :num_inference_steps (:iterations setup)
                      :guidance_scale (:scale setup))]
    (extract-first-image result)))