001  (ns proomp.util.pipe-utils
002    (:require [proomp.config :as config]
003              [proomp.domain.prompt.prompt :as prom]
004              [proomp.domain.image.resolution :as res]
005              [proomp.domain.pipe.pipe-setup :as pipe-setup]
006              [cambium.core :as log]
007              [libpython-clj2.require :refer [require-python]]
008              [libpython-clj2.python :refer [py. py.-] :as py])
009    (:import (proomp.domain.image.resolution Resolution)
010             (proomp.domain.prompt.prompt Prompt)))
011  
012  (require-python 'torch '[torch.cuda :as cuda])
013  (require-python 'transformers)
014  (require-python '[diffusers :refer [StableDiffusionPipeline
015                                      StableDiffusionImg2ImgPipeline
016                                      StableDiffusionUpscalePipeline]])
017  
018  ;(require-python '[riffusion :refer [RiffusionPipeline]])
019  
020  (defonce device "cuda")
021  (defonce enable-attention-slicing? true)
022  (defonce use-memory-efficient-attention? true)
023  
024  (defn- clear-cuda-cache []
025    (cuda/empty_cache)
026    (log/trace "Cuda cache cleared."))
027  
028  (defn- send-to-device [pipe] (py. pipe "to" device))
029  (defn- enable-attention-slicing [pipe]
030    (py. pipe "enable_attention_slicing")
031    (log/debug "Attention slicing enabled."))
032  
033  (defn- enable-memory-efficient-attention [pipe]
034    (py. pipe "set_use_memory_efficient_attention_xformers" true)
035    (log/debug "Memory efficient attention enabled."))
036  
037  (defn- ->pipeline [type model-path]
038    (log/debug "Creating Pipeline.")
039    (clear-cuda-cache)
040    (let [pipe (py. type "from_pretrained" model-path
041                    :torch_dtype torch/float16)]
042      (send-to-device pipe)
043      (if enable-attention-slicing?
044        (enable-attention-slicing pipe))
045      (if use-memory-efficient-attention?
046        (enable-memory-efficient-attention pipe))
047      (log/trace {:pipe pipe})
048      pipe))
049  
050  (defn ->text-to-image-pipeline [] (->pipeline StableDiffusionPipeline config/model-path))
051  (defn ->image-to-image-pipeline [] (->pipeline StableDiffusionImg2ImgPipeline config/model-path))
052  (defn ->upscaler-pipeline [] (->pipeline StableDiffusionUpscalePipeline config/upscaler-model-path))
053  ;(defn ->riffusion-pipeline [] (->pipeline RiffusionPipeline config/riffusion-model-path))
054  
055  (defn- ->generator [seed] (py. (py/$c torch/Generator device) "manual_seed" seed))
056  
057  (defn- extract-first-image [result] (nth (py.- result :images) 0))
058  
059  (defn generate-image [pipe ^Prompt prompt seed]
060    (let [^Resolution resolution res/active-image-resolution]
061      (extract-first-image
062        (py/$c pipe (prom/full-prompt prompt)
063               :negative_prompt (prom/full-negative-prompt prompt)
064               :generator (->generator seed)
065               :height (:h resolution) :width (:w resolution)
066               :num_inference_steps (:iterations pipe-setup/image-pipe-setup)
067               :guidance_scale (:scale pipe-setup/image-pipe-setup)))))
068  
069  (defn generate-upscale [up-pipe ^Prompt prompt image]
070    (extract-first-image
071      (py/$c up-pipe (prom/full-prompt prompt)
072             :negative_prompt (prom/full-negative-prompt prompt)
073             :image image
074             :num_inference_steps 75
075             :guidance_scale 9.0
076             :noise_level 20)))
077  
078  (defn generate-i2i [pipe ^Prompt prompt seed init-image]
079    (extract-first-image
080      (py/$c pipe (prom/full-prompt prompt)
081             :negative_prompt (prom/full-negative-prompt prompt)
082             :init_image init-image
083             :strength (:noise pipe-setup/i2i-pipe-setup)
084             :generator (->generator seed)
085             :num_inference_steps (:iterations pipe-setup/i2i-pipe-setup)
086             :guidance_scale (:scale pipe-setup/i2i-pipe-setup))))